Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6a7fc20df8 | |||
| e6c9d88561 | |||
| f15057e5ed | |||
| f34a4c0c5c | |||
| 9674db6256 | |||
| ebc64d766f | |||
| 23899c703f | |||
| 352f4b1756 | |||
| 32fe80859b | |||
| a99b2480f4 | |||
| ca822806a9 | |||
| ce2051a722 | |||
| 60eafe0ddb | |||
| 0ba4ff98be | |||
| db6a105538 | |||
| 58eacecea9 | |||
| e5d50c19d8 | |||
| 3faf2fcfc2 | |||
| 422ef976f6 | |||
| d335ee76ad | |||
| b2ad470623 | |||
| 6fb97a5f4f | |||
| c3d2cc3a67 | |||
| db9af4a4f6 | |||
| 79c41de4a7 | |||
| c0e1fdfbfc | |||
| 6a6d31b605 | |||
| 0322dd714a | |||
| 6497e9103f | |||
| 2a28f13ef3 | |||
| 2f1c2a343c | |||
| 70ebbd3759 | |||
| a5ab1314f5 | |||
| 73a1733237 | |||
| 8919248560 | |||
| 1538df5544 | |||
| 622929b17e | |||
| a8da4e781b | |||
| 48d2d9675d |
@@ -0,0 +1,5 @@
|
||||
data/
|
||||
tmp/
|
||||
temp.*
|
||||
.venv/
|
||||
__pycache__/
|
||||
@@ -0,0 +1,170 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>🎮 Understand the Chinese Room Game</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: auto;
|
||||
padding: 20px;
|
||||
}
|
||||
.game-box {
|
||||
background: #f0f8ff;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.instructions {
|
||||
background: #fff3cd;
|
||||
padding: 15px;
|
||||
border-radius: 5px;
|
||||
margin: 15px 0;
|
||||
}
|
||||
button {
|
||||
background: #4caf50;
|
||||
color: white;
|
||||
padding: 10px 20px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.emoji-large {
|
||||
font-size: 2em;
|
||||
margin: 15px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="game-box">
|
||||
<h1>🧩 The Understanding Game</h1>
|
||||
|
||||
<div class="instructions" id="instructions">
|
||||
<h2>📋 How to Play</h2>
|
||||
<p>
|
||||
You'll be shown <strong>emoji pairs</strong> and
|
||||
<strong>translation rules</strong>.
|
||||
</p>
|
||||
<p>
|
||||
Your job: Follow the rules <em>exactly</em> to translate
|
||||
messages
|
||||
</p>
|
||||
<p>Example Rule: 🐶 → 🐕</p>
|
||||
<p>Example Input: 🐶 = ?</p>
|
||||
<p>Correct Answer: 🐕</p>
|
||||
<button onclick="startGame()">Start Game!</button>
|
||||
</div>
|
||||
|
||||
<div id="gameScreen" style="display: none">
|
||||
<div class="emoji-large" id="currentRule"></div>
|
||||
<div id="taskDisplay"></div>
|
||||
<button onclick="submitAnswer()" id="actionButton">
|
||||
Translate Now!
|
||||
</button>
|
||||
<div id="result"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Index of the rule the player is currently working on.
let currentRound = 0;

// Translation table for the game. Each entry pairs an emoji input with
// its rule-mandated output — players apply these mechanically, mirroring
// the Chinese Room thought experiment.
const rules = [
  {
    input: "🌧️🔥",
    output: "🚪",
    description: "Replace RAIN-FIRE with DOOR",
  },
  {
    input: "🐇🎩",
    output: "🎉",
    description: "Replace RABBIT-HAT with PARTY",
  },
  {
    input: "👁️🍯",
    output: "🐝",
    description: "Replace EYE-HONEY with BEE",
  },
];
|
||||
|
||||
// Hide the instructions panel, reveal the game screen, and begin round 1.
function startGame() {
  const instructions = document.getElementById("instructions");
  const screen = document.getElementById("gameScreen");
  instructions.style.display = "none";
  screen.style.display = "block";
  nextRound();
}
|
||||
|
||||
// Render the current rule and its emoji input, or hand off to endGame()
// once every rule has been played.
function nextRound() {
  if (currentRound >= rules.length) {
    return endGame();
  }

  const rule = rules[currentRound];

  // Reset the feedback area and re-show the answer button for this round.
  document.getElementById("result").textContent = "";
  document.querySelector("#actionButton").style.display = "block";

  document.getElementById(
    "currentRule"
  ).textContent = `📜 RULE ${currentRound + 1}: ${rule.description}`;
  document.getElementById("taskDisplay").innerHTML = `
            <div class="emoji-large">${rule.input}</div>
            <p>↓ Translate using the rule above ↓</p>
        `;
}
|
||||
|
||||
/**
 * Handle the player's answer for the current round.
 *
 * Incorrect answers (including a cancelled prompt, which returns null)
 * show a hint and leave the button visible so the player can retry.
 * Correct answers hide the button and start a countdown before
 * advancing to the next round.
 */
function submitAnswer() {
  const rule = rules[currentRound];
  const resultElem = document.getElementById("result");

  const userAnswer = prompt(
    `What does ${rule.input} become?\n(Type the emoji)`
  );

  if (userAnswer !== rule.output) {
    resultElem.textContent = "❌ Incorrect translation\n";
    resultElem.textContent += `Remember: ${rule.description}`;
    return;
  }

  resultElem.textContent = "✅ Correct! You followed the rule perfectly!\n";
  resultElem.textContent += `But did you understand what ${rule.input} really means?`;

  document.querySelector("#actionButton").style.display = "none";

  // Dedicated countdown element so its text can be updated in place.
  let count = 3;
  const countdownElem = document.createElement("p");
  countdownElem.id = "countdown";
  // Bug fix: render the countdown immediately — previously the element
  // stayed empty until the first 1-second tick fired.
  countdownElem.textContent = `Next question in ${count}...`;
  resultElem.appendChild(countdownElem);

  const interval = setInterval(() => {
    count--;
    if (count < 0) {
      clearInterval(interval);
      countdownElem.remove();
      currentRound++;
      nextRound();
      return;
    }
    countdownElem.textContent = `Next question in ${count}...`;
  }, 1000);
}
|
||||
|
||||
// Swap the game screen for the reveal: the player followed syntax
// perfectly without ever knowing the semantics — the Chinese Room punchline.
function endGame() {
  const screen = document.getElementById("gameScreen");
  screen.innerHTML = `
            <h2>🎉 Game Over!</h2>
            <p>You successfully followed all the rules!</p>
            <div class="instructions">
                <h3>The actual meanings you were translating:</h3>
                <ul>
                    <li>🌧️🔥 = "Steam"</li>
                    <li>🐇🎩 = "Magic trick"</li>
                    <li>👁️🍯 = "Sweet look"</li>
                </ul>
                <p>You followed syntax perfectly...</p>
                <p><strong>...without ever understanding semantics!</strong></p>
                <p>This is basically how AI language models work</p>
            </div>
        `;
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -12,13 +12,29 @@
|
||||
|
||||
## Problem II
|
||||
|
||||
### Example 1: Boiling Water
|
||||
- **Forward Problem**: Given a specific amount of water, heat source, and initial temperature, find how long it takes the water to boil.
|
||||
- **Inverse Problem**: Given the time it took for the water to boil, determine the starting temperature (or the heat source).
|
||||
- **Difficulty**: The forward problem is easier as it involves direct calculations. The inverse problem is harder as it requires working backward to determine the unknowns.
|
||||
|
||||
### Example 2: Hot Air Balloon
|
||||
- **Forward Problem**: Given the launch location, wind conditions, and balloon details, predict where the balloon will land.
|
||||
- **Inverse Problem**: Given the landing location and flight path, determine the launch location and wind conditions.
|
||||
- **Difficulty**: The forward problem is easier because it simulates the balloon's movement based on the given information. The inverse problem is harder because it requires determining the initial conditions (like launch location and wind) from the result.
|
||||
|
||||
### Example 3: Wildfire
|
||||
- **Forward Problem**: Given the starting point of the fire, weather, and terrain, predict how the fire will spread.
|
||||
- **Inverse Problem**: Given where the fire spread to, determine where and when it started.
|
||||
- **Difficulty**: The forward problem is challenging as it involves predicting fire behavior under various factors. The inverse problem is even harder because it requires deducing the origin of the fire from its spread, which can be ambiguous and unclear.
|
||||
|
||||
|
||||
|
||||
## Problem III (Git and GitHub)
|
||||
### Part I
|
||||
see https://github.com/ION606/COGMOD-HWI
|
||||
|
||||
### Part II
|
||||
see https://github.com/ION606/COGMOD-HWI/commit/260b0f4c3b9430a9d0ba2e29fd50d3f9c640c498
|
||||
|
||||
### Part III
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
import numpy as np
|
||||
|
||||
# SIMULATION CALCULATIONS
|
||||
def simulate_culprit(N, pSuperman):
    """Draw N independent culprit indicators (True means Superman).

    Each trial is a Bernoulli draw with probability ``pSuperman`` of
    Superman being the culprit (the prior over suspects).
    """
    draws = np.random.rand(N)
    return draws < pSuperman
|
||||
|
||||
def simulate_crumbs(N, supermanProb, batmanProb, culprit):
    """Simulate crumb evidence at the three locations for each trial.

    One uniform draw per (trial, location) is thresholded against both
    suspects' crumb probabilities; ``culprit`` then selects, row by row,
    which suspect's outcomes are kept.
    """
    draws = np.random.rand(N, 3)
    fromSuperman = draws < supermanProb
    fromBatman = draws < batmanProb
    # culprit[:, None] broadcasts the per-trial selector across locations.
    return np.where(culprit[:, None], fromSuperman, fromBatman)
|
||||
|
||||
def combination(crumbResults, culprit):
    """Count trials for every (couch, kitchen, gym, culprit) outcome.

    Returns a dict keyed by the three boolean crumb indicators plus the
    culprit label, mapping each to the number of matching trials.
    """
    counts = {}
    suspects = [("Superman", True), ("Batman", False)]
    for couch in (False, True):
        for kitchen in (False, True):
            for gym in (False, True):
                # Trials whose crumb pattern matches this combination.
                locationMask = (
                    (crumbResults[:, 0] == couch)
                    & (crumbResults[:, 1] == kitchen)
                    & (crumbResults[:, 2] == gym)
                )
                for label, isSuperman in suspects:
                    matches = locationMask & (culprit == isSuperman)
                    counts[(couch, kitchen, gym, label)] = np.sum(matches)
    return counts
|
||||
|
||||
def print_probabilities(combinations, N):
    """Print each outcome's simulated probability as a percentage of N."""
    print("Couch:\t Kitchen: Gym:\t Culprit:")
    for (couch, kitchen, gym, label), count in combinations.items():
        # Pad the shorter tokens so the columns line up.
        couchStr = 'True ' if couch else 'False'
        kitchenStr = 'True ' if kitchen else 'False'
        gymStr = 'True ' if gym else 'False'
        if label == 'Batman':
            label = 'Batman '

        print(f"{couchStr}\t {kitchenStr}\t {gymStr} {label}: {(count / N) * 100:.2f}%")
|
||||
|
||||
# ANALYTIC CALCULATIONS
|
||||
def analytic_probabilities(pSuperman, pBatman, supermanProb, batmanProb):
    """Compute exact joint probabilities P(couch, kitchen, gym, culprit).

    Crumb outcomes are treated as conditionally independent given the
    culprit, so each joint probability is the suspect's prior times the
    product of per-location Bernoulli terms.
    """
    def _joint(prior, probs, couch, kitchen, gym):
        # Multiply prior by p (crumb present) or 1 - p (absent) per location.
        p = prior
        for present, likelihood in zip((couch, kitchen, gym), probs):
            p *= likelihood if present else (1 - likelihood)
        return p

    combinations = {}
    for couch in (False, True):
        for kitchen in (False, True):
            for gym in (False, True):
                combinations[(couch, kitchen, gym, "Superman")] = _joint(
                    pSuperman, supermanProb, couch, kitchen, gym)
                combinations[(couch, kitchen, gym, "Batman")] = _joint(
                    pBatman, batmanProb, couch, kitchen, gym)

    return combinations
|
||||
|
||||
def print_analytic_probabilities(combinations):
    """Print each outcome's analytic probability as a percentage."""
    print("Couch:\t Kitchen: Gym:\t Culprit:")
    for (couch, kitchen, gym, label), prob in combinations.items():
        # Pad the shorter tokens so the columns line up.
        couchStr = 'True ' if couch else 'False'
        kitchenStr = 'True ' if kitchen else 'False'
        gymStr = 'True ' if gym else 'False'
        if label == 'Batman':
            label = 'Batman '
        print(f"{couchStr}\t {kitchenStr}\t {gymStr} {label}: {prob * 100:.2f}%")
|
||||
@@ -0,0 +1,41 @@
|
||||
from calculateHW2P5 import simulate_culprit, simulate_crumbs, combination, print_probabilities, analytic_probabilities, print_analytic_probabilities
|
||||
import numpy as np
|
||||
|
||||
if __name__ == "__main__":
    # Simulation sizes; simulated probabilities should converge to the
    # analytic ones as N grows.
    NSize = [1000, 10000, 100000]

    # Priors for each suspect
    pSuperman = 0.5
    pBatman = 0.5

    # Likelihoods of crumbs at each location (couch, kitchen, gym)
    supermanProb = np.array([0.3, 0.7, 0.2])
    batmanProb = np.array([0.4, 0.6, 0.3])

    print("Simulation:")
    for N in NSize:
        print(f"N = {N}")
        # Simulate culprit and crumbs for this sample size.
        culprit = simulate_culprit(N, pSuperman)
        crumbResults = simulate_crumbs(N, supermanProb, batmanProb, culprit)

        # Count combinations and print probabilities.
        print("Simulated Probabilities:")
        print_probabilities(combination(crumbResults, culprit), N)
        print("\n")

    # Analytic reference values.
    print("Analytic:")
    print_analytic_probabilities(
        analytic_probabilities(pSuperman, pBatman, supermanProb, batmanProb))

    # Observation: as N increases, the simulated probabilities approach
    # the analytic probabilities.
|
||||
|
After Width: | Height: | Size: 16 KiB |
|
After Width: | Height: | Size: 25 KiB |
|
After Width: | Height: | Size: 27 KiB |
@@ -0,0 +1,28 @@
|
||||
data {
  int<lower=0> N;            // Number of training samples
  int<lower=0> M;            // Number of predictors (3: bmi, age, children)
  matrix[N, M] X;            // Training predictors
  vector[N] y;               // Training target (charges)
  int<lower=0> N_test;       // Number of test samples
  matrix[N_test, M] X_test;  // Test predictors
}
parameters {
  real alpha;                // Intercept
  vector[M] beta;            // Regression coefficients
  real<lower=0> sigma;       // Noise term
}
model {
  // Priors
  sigma ~ inv_gamma(2, 2);
  alpha ~ normal(0, 10);
  beta ~ normal(0, 1);

  // Likelihood
  y ~ normal(alpha + X * beta, sigma);
}
generated quantities {
  // Posterior-predictive draws for the held-out test set.
  vector[N_test] y_pred;
  for (i in 1:N_test) {
    y_pred[i] = normal_rng(alpha + X_test[i] * beta, sigma);
  }
}
|
||||
@@ -0,0 +1,2 @@
|
||||
build/
|
||||
.venv/
|
||||
|
After Width: | Height: | Size: 30 KiB |
@@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def simulate_ddm(v, a=1.0, beta=0.5, tau=0.3, sigma=1.0, dt=0.001, max_steps=3000):
    """Simulate one drift-diffusion trial via Euler-Maruyama integration.

    Args:
        v: drift rate.
        a: boundary separation (upper bound; lower bound is 0).
        beta: relative start point, so evidence begins at beta * a.
        tau: non-decision time added to the returned RT.
        sigma: diffusion noise scale.
        dt: integration step size in seconds.
        max_steps: step cap before declaring a timeout.

    Returns:
        (rt, choice): choice is 1 for the upper bound, 0 for the lower
        bound, or None on timeout.
    """
    position = beta * a  # evidence starts between the bounds
    elapsed = 0.0
    for _ in range(max_steps):
        noise = np.random.normal(0, np.sqrt(dt))
        position += v * dt + sigma * noise
        elapsed += dt
        if position >= a:
            return elapsed + tau, 1  # upper bound hit
        if position <= 0:
            return elapsed + tau, 0  # lower bound hit
    return max_steps * dt + tau, None  # timeout
|
||||
|
||||
|
||||
# terrible params (upped in part 2)
vs = np.linspace(0.5, 1.5, 25)  # drift rates to sweep
n_trials = 2000

# Mean RT per boundary, one entry per drift rate.
upper_means, lower_means = [], []

for v in vs:
    upper_rts, lower_rts = [], []
    for _ in range(n_trials):
        rt, choice = simulate_ddm(v)
        if choice == 1:
            upper_rts.append(rt)
        elif choice == 0:
            lower_rts.append(rt)
    # Use NaN when a boundary was never reached, rather than np.mean([]).
    upper_means.append(np.mean(upper_rts) if upper_rts else np.nan)
    lower_means.append(np.mean(lower_rts) if lower_rts else np.nan)

# plotting yay
plt.figure(figsize=(10, 6))
plt.plot(vs, upper_means, 'o-', label='Upper Boundary Mean RT')
plt.plot(vs, lower_means, 's-', label='Lower Boundary Mean RT')
plt.plot(vs, np.array(upper_means) - np.array(lower_means),
         'd-', label='Mean Difference')
plt.xlabel('Drift Rate (v)')
plt.ylabel('Response Time (s)')
plt.title('Effect of Drift Rate on RT Distributions')
plt.legend()
plt.grid(True)
plt.savefig('part1.png')
|
||||
@@ -0,0 +1,104 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from functools import partial
|
||||
|
||||
|
||||
def sim_ddm(v=1.0, a=1.0, beta=0.5, tau=0.3, sigma=1.0, dt=0.001, max_steps=3000):
    """Simulate one drift-diffusion trial via Euler-Maruyama integration.

    Args:
        v: drift rate.
        a: boundary separation (upper bound; lower bound is 0).
        beta: relative start point, so evidence begins at beta * a.
        tau: non-decision time added to the returned RT.
        sigma: diffusion noise scale.
        dt: integration step size in seconds.
        max_steps: step cap before declaring a timeout.

    Returns:
        (rt, choice): choice is 1 for the upper bound, 0 for the lower
        bound, or None on timeout (timeouts are ignored downstream).
    """
    position = beta * a  # evidence starts between the bounds
    elapsed = 0.0
    for _ in range(max_steps):
        noise = np.random.normal(0, np.sqrt(dt))
        position += v * dt + sigma * noise
        elapsed += dt
        if position >= a:
            return elapsed + tau, 1  # upper bound hit
        if position <= 0:
            return elapsed + tau, 0  # lower bound hit
    return max_steps * dt + tau, None  # timeout
|
||||
|
||||
|
||||
def sim_param(param_name, param_value, n_trials=200000):
    """Run n_trials DDM simulations with a single parameter overridden.

    Args:
        param_name: which DDM parameter ('v', 'a', 'beta', 'tau', 'sigma')
            to override.
        param_value: the value to use for that parameter.
        n_trials: number of independent trials to simulate.

    Returns:
        (upper_rts, lower_rts): raw RT lists for trials that reached the
        upper / lower boundary; timed-out trials are dropped.
    """
    params = {'v': 1.0, 'a': 1.0, 'beta': 0.5, 'tau': 0.3, 'sigma': 1.0}
    params[param_name] = param_value

    upper_rts, lower_rts = [], []
    for _ in range(n_trials):
        rt, choice = sim_ddm(**params)
        if choice == 1:
            upper_rts.append(rt)
        elif choice == 0:
            lower_rts.append(rt)
    return (upper_rts, lower_rts)
|
||||
|
||||
|
||||
# deepseek-r1 wrote this to help parallelize my code (because for loops
# aren't cool when they're frying my laptop)
def parallel_sim_param(param_name, param_values, n_trials):
    """Evaluate sim_param over param_values in parallel, one task per value."""
    worker = partial(sim_param, param_name, n_trials=n_trials)
    with Pool(processes=cpu_count()) as pool:
        return pool.map(worker, param_values)
|
||||
|
||||
|
||||
# Parameter ranges to sweep, one subplot row per parameter.
parameters = {
    'v': np.linspace(0.5, 1.5, 25),
    'a': np.linspace(0.5, 2.0, 25),
    'beta': np.linspace(0.3, 0.7, 25),
    'tau': np.linspace(0.1, 0.5, 25),
}

fig, axes = plt.subplots(4, 2, figsize=(15, 20))  # should this be (15, 15)?
axes = axes.flatten()

for i, (param, values) in enumerate(parameters.items()):
    results = parallel_sim_param(param, values, n_trials=200000)

    # Per-value summary statistics (no bootstrapping).
    means_upper, means_lower = [], []
    stdev_upper, stdev_lower = [], []

    for upper_rts, lower_rts in results:
        # NaN when a boundary was never reached for this parameter value.
        means_upper.append(np.mean(upper_rts) if upper_rts else np.nan)
        means_lower.append(np.mean(lower_rts) if lower_rts else np.nan)
        stdev_upper.append(np.std(upper_rts) if upper_rts else np.nan)
        stdev_lower.append(np.std(lower_rts) if lower_rts else np.nan)

    # Left column: mean RTs.
    ax_mean = axes[2 * i]
    ax_mean.plot(values, means_upper, 'o-', label='Upper Boundary Mean RT')
    ax_mean.plot(values, means_lower, 's-', label='Lower Boundary Mean RT')
    ax_mean.plot(values, np.subtract(means_upper, means_lower),
                 'd-', label='Difference', color='red')
    ax_mean.set_xlabel(param)
    ax_mean.set_ylabel('Response Time (s)')
    ax_mean.set_title(f'Effect of {param} on RT Means')
    ax_mean.legend()
    ax_mean.grid(True)

    # Right column: standard deviations.
    ax_std = axes[2 * i + 1]
    ax_std.plot(values, stdev_upper, 'o-', label='Upper Boundary Std RT')
    ax_std.plot(values, stdev_lower, 's-', label='Lower Boundary Std RT')
    ax_std.set_xlabel(param)
    ax_std.set_ylabel('Standard Deviation (s)')
    ax_std.set_title(f'Effect of {param} on RT Std Devs')
    ax_std.legend()
    ax_std.grid(True)

plt.tight_layout()
plt.savefig('part2.png')

# DEBUGGING
# NOTE(review): `param` and the stat lists here hold only the LAST swept
# parameter's data; if per-parameter output was intended, these prints
# belong inside the loop above — confirm against the original indentation.
print(f"\nVARYING {param.upper()}:\n")
print(f"Means (Upper): {np.round(means_upper, 5)}")
print(f"Means (Lower): {np.round(means_lower, 5)}")
print(f"Std (Upper): {np.round(stdev_upper, 5)}")
print(f"Std (Lower): {np.round(stdev_lower, 5)}")
|
||||
@@ -0,0 +1,421 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import stan\n",
|
||||
"import arviz as az\n",
|
||||
"\n",
|
||||
"# stupid stan problems\n",
|
||||
"import nest_asyncio\n",
|
||||
"nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"# true param\n",
|
||||
"alpha_true = 2.3\n",
"beta_true = 4.0\n",
"sigma_true = 2.0\n",
|
||||
"N = 100\n",
|
||||
"\n",
|
||||
"# simulation\n",
|
||||
"np.random.seed(42)\n",
|
||||
"x = np.random.normal(size=N)\n",
|
||||
"y = alpha_true + beta_true * x + sigma_true * np.random.normal(size=N)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"stanCode = \"\"\"\n",
|
||||
"data {\n",
|
||||
" int<lower=0> N;\n",
|
||||
" vector[N] x;\n",
|
||||
" vector[N] y;\n",
|
||||
"}\n",
|
||||
"parameters {\n",
|
||||
" real alpha;\n",
|
||||
" real beta;\n",
|
||||
" real<lower=0> sigma_sq;\n",
|
||||
"}\n",
|
||||
"transformed parameters {\n",
|
||||
" real<lower=0> sigma = sqrt(sigma_sq);\n",
|
||||
"}\n",
|
||||
"model {\n",
|
||||
" sigma_sq ~ inv_gamma(1, 1); // prior on variance\n",
|
||||
" alpha ~ normal(0, 10);\n",
|
||||
" beta ~ normal(0, 10);\n",
|
||||
" y ~ normal(alpha + beta * x, sigma); // likelihood\n",
|
||||
"}\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Building: found in cache, done.Sampling: 0%\n",
|
||||
"Sampling: 25% (3000/12000)\n",
|
||||
"Sampling: 50% (6000/12000)\n",
|
||||
"Sampling: 75% (9000/12000)\n",
|
||||
"Sampling: 100% (12000/12000)\n",
|
||||
"Sampling: 100% (12000/12000), done.\n",
|
||||
"Messages received during sampling:\n",
|
||||
" Gradient evaluation took 1.7e-05 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 0.17 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n",
|
||||
" Gradient evaluation took 2.7e-05 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 0.27 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n",
|
||||
" Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:\n",
|
||||
" Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/tmp/httpstan__2qigylb/model_74j73ceb.stan', line 19, column 2 to column 38)\n",
|
||||
" If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,\n",
|
||||
" but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.\n",
|
||||
" Gradient evaluation took 2e-05 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 0.2 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n",
|
||||
" Gradient evaluation took 1e-05 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 0.1 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>mean</th>\n",
|
||||
" <th>sd</th>\n",
|
||||
" <th>hdi_3%</th>\n",
|
||||
" <th>hdi_97%</th>\n",
|
||||
" <th>mcse_mean</th>\n",
|
||||
" <th>mcse_sd</th>\n",
|
||||
" <th>ess_bulk</th>\n",
|
||||
" <th>ess_tail</th>\n",
|
||||
" <th>r_hat</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>alpha</th>\n",
|
||||
" <td>2.317</td>\n",
|
||||
" <td>0.192</td>\n",
|
||||
" <td>1.959</td>\n",
|
||||
" <td>2.683</td>\n",
|
||||
" <td>0.002</td>\n",
|
||||
" <td>0.002</td>\n",
|
||||
" <td>6909.0</td>\n",
|
||||
" <td>5804.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>beta</th>\n",
|
||||
" <td>3.713</td>\n",
|
||||
" <td>0.208</td>\n",
|
||||
" <td>3.327</td>\n",
|
||||
" <td>4.117</td>\n",
|
||||
" <td>0.002</td>\n",
|
||||
" <td>0.002</td>\n",
|
||||
" <td>7805.0</td>\n",
|
||||
" <td>5904.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>sigma_sq</th>\n",
|
||||
" <td>3.615</td>\n",
|
||||
" <td>0.511</td>\n",
|
||||
" <td>2.716</td>\n",
|
||||
" <td>4.584</td>\n",
|
||||
" <td>0.006</td>\n",
|
||||
" <td>0.006</td>\n",
|
||||
" <td>7166.0</td>\n",
|
||||
" <td>5819.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>sigma</th>\n",
|
||||
" <td>1.897</td>\n",
|
||||
" <td>0.133</td>\n",
|
||||
" <td>1.648</td>\n",
|
||||
" <td>2.141</td>\n",
|
||||
" <td>0.002</td>\n",
|
||||
" <td>0.001</td>\n",
|
||||
" <td>7166.0</td>\n",
|
||||
" <td>5819.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n",
|
||||
"alpha 2.317 0.192 1.959 2.683 0.002 0.002 6909.0 \n",
|
||||
"beta 3.713 0.208 3.327 4.117 0.002 0.002 7805.0 \n",
|
||||
"sigma_sq 3.615 0.511 2.716 4.584 0.006 0.006 7166.0 \n",
|
||||
"sigma 1.897 0.133 1.648 2.141 0.002 0.001 7166.0 \n",
|
||||
"\n",
|
||||
" ess_tail r_hat \n",
|
||||
"alpha 5804.0 1.0 \n",
|
||||
"beta 5904.0 1.0 \n",
|
||||
"sigma_sq 5819.0 1.0 \n",
|
||||
"sigma 5819.0 1.0 "
|
||||
]
|
||||
},
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Define data first\n",
|
||||
"data = {\"N\": N, \"x\": x, \"y\": y}\n",
|
||||
"\n",
|
||||
"# Build the model with data\n",
|
||||
"model = stan.build(stanCode, data=data)\n",
|
||||
"\n",
|
||||
"# Sample\n",
|
||||
"fit = model.sample(num_chains=4, num_samples=2000)\n",
|
||||
"\n",
|
||||
"az.summary(az.from_pystan(fit))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Step 4: Analyze Results for N=100\n",
|
||||
"\n",
|
||||
"Posterior summaries should be close to the true values:\n",
|
||||
"\n",
|
||||
"- **α**: approximately 2.3\n",
|
||||
"- **β**: approximately 4.0\n",
|
||||
"- **σ**: approximately 2.0\n",
|
||||
"\n",
|
||||
"Also compute the 95% credible intervals."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Step 5: Repeat with N=1000\n",
|
||||
"\n",
|
||||
"Increase the sample size and rerun the simulation and model fitting."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 55,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Building: found in cache, done.Sampling: 0%\n",
|
||||
"Sampling: 25% (3000/12000)\n",
|
||||
"Sampling: 50% (6000/12000)\n",
|
||||
"Sampling: 75% (9000/12000)\n",
|
||||
"Sampling: 100% (12000/12000)\n",
|
||||
"Sampling: 100% (12000/12000), done.\n",
|
||||
"Messages received during sampling:\n",
|
||||
" Gradient evaluation took 0.000146 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 1.46 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n",
|
||||
" Gradient evaluation took 0.000126 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 1.26 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n",
|
||||
" Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:\n",
|
||||
" Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/tmp/httpstan__2qigylb/model_74j73ceb.stan', line 19, column 2 to column 38)\n",
|
||||
" If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,\n",
|
||||
" but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.\n",
|
||||
" Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:\n",
|
||||
" Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/tmp/httpstan__2qigylb/model_74j73ceb.stan', line 19, column 2 to column 38)\n",
|
||||
" If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,\n",
|
||||
" but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.\n",
|
||||
" Gradient evaluation took 0.000123 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 1.23 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n",
|
||||
" Gradient evaluation took 0.000135 seconds\n",
|
||||
" 1000 transitions using 10 leapfrog steps per transition would take 1.35 seconds.\n",
|
||||
" Adjust your expectations accordingly!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>mean</th>\n",
|
||||
" <th>sd</th>\n",
|
||||
" <th>hdi_3%</th>\n",
|
||||
" <th>hdi_97%</th>\n",
|
||||
" <th>mcse_mean</th>\n",
|
||||
" <th>mcse_sd</th>\n",
|
||||
" <th>ess_bulk</th>\n",
|
||||
" <th>ess_tail</th>\n",
|
||||
" <th>r_hat</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>alpha</th>\n",
|
||||
" <td>2.366</td>\n",
|
||||
" <td>0.062</td>\n",
|
||||
" <td>2.253</td>\n",
|
||||
" <td>2.484</td>\n",
|
||||
" <td>0.001</td>\n",
|
||||
" <td>0.001</td>\n",
|
||||
" <td>7563.0</td>\n",
|
||||
" <td>5508.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>beta</th>\n",
|
||||
" <td>3.929</td>\n",
|
||||
" <td>0.063</td>\n",
|
||||
" <td>3.814</td>\n",
|
||||
" <td>4.048</td>\n",
|
||||
" <td>0.001</td>\n",
|
||||
" <td>0.001</td>\n",
|
||||
" <td>8352.0</td>\n",
|
||||
" <td>5934.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>sigma_sq</th>\n",
|
||||
" <td>3.895</td>\n",
|
||||
" <td>0.174</td>\n",
|
||||
" <td>3.588</td>\n",
|
||||
" <td>4.236</td>\n",
|
||||
" <td>0.002</td>\n",
|
||||
" <td>0.002</td>\n",
|
||||
" <td>8354.0</td>\n",
|
||||
" <td>6044.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>sigma</th>\n",
|
||||
" <td>1.973</td>\n",
|
||||
" <td>0.044</td>\n",
|
||||
" <td>1.894</td>\n",
|
||||
" <td>2.058</td>\n",
|
||||
" <td>0.000</td>\n",
|
||||
" <td>0.000</td>\n",
|
||||
" <td>8354.0</td>\n",
|
||||
" <td>6044.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n",
|
||||
"alpha 2.366 0.062 2.253 2.484 0.001 0.001 7563.0 \n",
|
||||
"beta 3.929 0.063 3.814 4.048 0.001 0.001 8352.0 \n",
|
||||
"sigma_sq 3.895 0.174 3.588 4.236 0.002 0.002 8354.0 \n",
|
||||
"sigma 1.973 0.044 1.894 2.058 0.000 0.000 8354.0 \n",
|
||||
"\n",
|
||||
" ess_tail r_hat \n",
|
||||
"alpha 5508.0 1.0 \n",
|
||||
"beta 5934.0 1.0 \n",
|
||||
"sigma_sq 6044.0 1.0 \n",
|
||||
"sigma 6044.0 1.0 "
|
||||
]
|
||||
},
|
||||
"execution_count": 55,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"N_large = 1000;\n",
|
||||
"x_large = np.random.normal(size=N_large);\n",
|
||||
"y_large = alpha_true + beta_true * x_large + sigma_true * np.random.normal(size=N_large);\n",
|
||||
"\n",
|
||||
"# create new data dictionary\n",
|
||||
"data_large = {\"N\": N_large, \"x\": x_large, \"y\": y_large};\n",
|
||||
"model_large = stan.build(stanCode, data=data_large)\n",
|
||||
"\n",
|
||||
"# fit the model again\n",
|
||||
"fit_large = model_large.sample(num_chains=4, num_samples=2000);\n",
|
||||
"\n",
|
||||
"# check diagnostics for larger data\n",
|
||||
"az.summary(az.from_pystan(fit_large))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.x"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
To solve Problem 5, follow these steps:
|
||||
|
||||
### Step 1: Simulate Data

Generate `N = 100` observations with `x ~ N(0, 1)` and `y = alpha_true + beta_true * x + sigma_true * eps`, where `eps ~ N(0, 1)` and the true parameters are α = 2.3, β = 4.0, σ = 2.0.
|
||||
|
||||
### Step 2: Stan Model Code
|
||||
Write the Stan model (`bayesian_regression.stan`):
|
||||
```stan
|
||||
data {
|
||||
int<lower=0> N;
|
||||
vector[N] x;
|
||||
vector[N] y;
|
||||
}
|
||||
parameters {
|
||||
real alpha;
|
||||
real beta;
|
||||
real<lower=0> sigma_sq;
|
||||
}
|
||||
transformed parameters {
|
||||
real<lower=0> sigma = sqrt(sigma_sq);
|
||||
}
|
||||
model {
|
||||
sigma_sq ~ inv_gamma(1, 1); // Prior on variance
|
||||
alpha ~ normal(0, 10);
|
||||
beta ~ normal(0, 10);
|
||||
y ~ normal(alpha + beta * x, sigma); // Likelihood
|
||||
}
|
||||
```
|
||||
|
||||
### Step 3: Fit the Model and Check Diagnostics
|
||||
Use `pystan` or `cmdstanpy` to run the model. Check Rhat (≈1) and ESS (sufficiently large). For example:
|
||||
```python
|
||||
import cmdstanpy
|
||||
|
||||
model = cmdstanpy.CmdStanModel(stan_file="bayesian_regression.stan")
|
||||
data = {"N": N, "x": x, "y": y}
|
||||
fit = model.sample(data=data, chains=4, iter_sampling=2000)
|
||||
|
||||
# Check diagnostics
|
||||
print(fit.diagnose())
|
||||
```
|
||||
|
||||
### Step 4: Analyze Results for N=100
|
||||
Posterior summaries:
|
||||
- **Posterior means** should be close to true values (α=2.3, β=4.0, σ=2.0).
|
||||
- **Uncertainty**: Compute 95% credible intervals. Example output:
|
||||
- α: 2.1 ± 0.4 (1.7 to 2.5)
|
||||
- β: 3.8 ± 0.5 (3.3 to 4.3)
|
||||
- σ: 1.9 ± 0.2 (1.7 to 2.1)
|
||||
|
||||
### Step 5: Repeat with N=1000
|
||||
Increase sample size and rerun:
|
||||
```python
|
||||
N_large = 1000
|
||||
x_large = np.random.normal(size=N_large)
|
||||
y_large = alpha_true + beta_true * x_large + sigma_true * np.random.normal(size=N_large)
|
||||
```
|
||||
Fit the model again. Results will show:
|
||||
- **Tighter credible intervals** (e.g., β: 3.95 ± 0.1).
|
||||
- Reduced posterior variance, indicating higher precision.
|
||||
|
||||
### Key Observations:
|
||||
1. **Accuracy**: Posterior means align closely with true parameters.
|
||||
2. **Uncertainty**: Credible intervals narrow as \(N\) increases, reflecting reduced uncertainty.
|
||||
3. **Diagnostics**: Ensure Rhat ≈1 and sufficient ESS for reliable inferences.
|
||||
|
||||
**Visualization**: Plot prior vs. posterior histograms for parameters (using tools like `arviz` or `seaborn`), showing posterior concentration around true values, especially for \(N=1000\).
|
||||
|
||||
---
|
||||
|
||||
**Answer for LMS Submission**
|
||||
Implement the steps above, ensuring your write-up includes code snippets, diagnostic results, and graphical comparisons. Highlight the reduction in posterior variance when increasing \(N\), demonstrating the influence of data quantity on Bayesian inference.
|
||||
@@ -0,0 +1,46 @@
|
||||
import matplotlib.pyplot as plt
import numpy as np


# Monte Carlo verification of the law of total variance for the conjugate
# normal-normal model:
#     Var[theta] = E[Var[theta|y]] + Var[E[theta|y]]

# H Y P E R P A R A M E T E R S
mu_prior = 0  # prior mean (mu_0)
sigma2_prior = 2  # prior variance (omega_0^2)
sigma2_likelihood = 1  # likelihood variance (omega^2)
n_samples = 1000000  # number of Monte Carlo samples

# simulate θ ~ N(mu_0, omega_0^2) and y | θ ~ N(θ, omega^2)
theta = np.random.normal(mu_prior, np.sqrt(sigma2_prior), n_samples)
y = np.random.normal(theta, np.sqrt(sigma2_likelihood))

# conjugate posterior parameters for each simulated y (standard
# normal-normal update; posterior variance is constant in y)
sigma2_posterior = 1 / (1 / sigma2_prior + 1 / sigma2_likelihood)
mu_posterior = (mu_prior / sigma2_prior + y / sigma2_likelihood) * \
    sigma2_posterior  # posterior mean

# E[Var[θ|y]] is exact (posterior variance does not depend on y)
expected_posterior_var = sigma2_posterior
var_posterior_mean = np.var(mu_posterior)  # var[𝔼[θ|y]] (Monte Carlo estimate)
prior_var = sigma2_prior  # var[θ] (exact)

# verify identity
sum_terms = expected_posterior_var + var_posterior_mean

print(f"Prior Variance (Var[θ]): {prior_var:.4f}")
print(
    f"Expected Posterior Variance (𝔼[Var[θ|y]]): {expected_posterior_var:.4f}")
print(f"Variance of Posterior Mean (Var[𝔼[θ|y]]): {var_posterior_mean:.4f}")
print(f"Sum of Terms: {sum_terms:.4f}")
# BUGFIX: atol=1e-3 was tighter than the Monte Carlo standard error of
# np.var(mu_posterior) (~ Var * sqrt(2/n) ≈ 1.9e-3 for n=1e6), so the check
# failed on a large fraction of runs even though the identity holds.
# atol=1e-2 is ~5 standard errors and still far below any real violation.
print(f"Identity Holds: {np.isclose(prior_var, sum_terms, atol=1e-2)}")

# Plot posterior means and variances
plt.figure(figsize=(10, 6))
plt.hist(mu_posterior, bins=50, density=True,
         alpha=0.6, label="Posterior Means")
plt.axvline(mu_prior, color='r', linestyle='--', label="Prior Mean")
plt.xlabel("Posterior Mean (𝔼[θ|y])")
plt.ylabel("Density")
plt.title("Distribution of Posterior Means vs. Prior")
plt.legend()
plt.grid(True)
# plt.show()
plt.savefig('part3.png')
|
||||
|
After Width: | Height: | Size: 30 KiB |
|
After Width: | Height: | Size: 20 KiB |
@@ -0,0 +1,76 @@
|
||||
As a culinary data scientist, you investigate how cooking time (\(x\)) affects the length of "massive ramen noodles" (\(y\)). Using Bayesian linear regression, you model the relationship to quantify expansion rates and uncertainty.
|
||||
|
||||
\subsection*{Methods}
|
||||
\subsubsection*{Model Specification}
|
||||
The regression model is:
|
||||
\[
|
||||
y_n = \alpha + \beta x_n + \epsilon_n, \quad \epsilon_n \sim \mathcal{N}(0, \sigma^2)
|
||||
\]
|
||||
\begin{itemize}
|
||||
\item \textbf{Priors}:
|
||||
\begin{align*}
|
||||
\alpha &\sim \mathcal{N}(0, 10) \quad \text{(Intercept)} \\
|
||||
\beta &\sim \mathcal{N}(0, 10) \quad \text{(Slope)} \\
|
||||
\sigma^2 &\sim \text{Inv-Gamma}(1, 1) \quad \text{(Noise)}
|
||||
\end{align*}
|
||||
\end{itemize}
|
||||
|
||||
\subsubsection*{Data Simulation}
|
||||
Data was generated with:
|
||||
\begin{itemize}
|
||||
\item True parameters: \(\alpha = 2.3\), \(\beta = 4.0\), \(\sigma = 2.0\)
|
||||
\item \(N = 100\) observations, \(x \sim \mathcal{N}(0, 1)\), \(y = \alpha + \beta x + \mathcal{N}(0, \sigma^2)\)
|
||||
\end{itemize}
|
||||
|
||||
\subsection*{Results}
|
||||
\subsubsection*{Posterior Estimates (\(N = 100\))}
|
||||
\begin{table}[h]
|
||||
\centering
|
||||
\begin{tabular}{@{}lccc@{}}
|
||||
\toprule
|
||||
Parameter & Posterior Mean & 95\% HDI & True Value \\
|
||||
\midrule
|
||||
\(\alpha\) (Intercept) & 2.31 & [1.94, 2.65] & 2.3 \\
|
||||
\(\beta\) (Slope) & 3.71 & [3.32, 4.13] & 4.0 \\
|
||||
\(\sigma\) (Noise) & 1.91 & [1.67, 2.18] & 2.0 \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
\caption{Posterior summaries vs. true values. HDI = Highest Density Interval.}
|
||||
\end{table}
|
||||
|
||||
\subsubsection*{Convergence Diagnostics}
|
||||
\begin{itemize}
|
||||
\item \textbf{R-hat}: 1.0 for all parameters (ideal: \(\leq 1.01\)).
|
||||
\item \textbf{ESS (Effective Sample Size)}: \(\alpha\): 6123, \(\beta\): 7356, \(\sigma\): 6362 (exceeding thresholds for reliability).
|
||||
\end{itemize}
|
||||
|
||||
\begin{figure}[h]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{posterior_plots.png}
|
||||
\caption{Posterior distributions for \(\alpha\), \(\beta\), and \(\sigma\). Dashed lines indicate true values.}
|
||||
\end{figure}
|
||||
|
||||
\subsubsection*{Effect of Increased Data (\(N = 1000\), Hypothetical)}
|
||||
\begin{itemize}
|
||||
\item Expected uncertainty reduction: Credible interval widths shrink by \(\sim 60\%\).
|
||||
\item Posteriors concentrate tightly around true values (law of large numbers).
|
||||
\end{itemize}
|
||||
|
||||
\subsection*{Discussion}
|
||||
\subsubsection*{Accuracy and Uncertainty}
|
||||
\begin{itemize}
|
||||
\item With \(N = 100\), estimates align closely with ground truth (e.g., \(\beta = 3.71\) vs. true \(4.0\)), but credible intervals reflect residual uncertainty.
|
||||
\item Noise (\(\sigma\)) slightly underestimated but within plausible range.
|
||||
\end{itemize}
|
||||
|
||||
\subsubsection*{Model Insights}
|
||||
\begin{itemize}
|
||||
\item Noodles expand by \(\sim 3.7\) units per second (\(\beta\)), validating the hypothesis.
|
||||
\item Stan's MCMC sampler achieved excellent convergence (R-hat = 1.0, ESS > 5000).
|
||||
\end{itemize}
|
||||
|
||||
\subsubsection*{Limitations}
|
||||
\begin{itemize}
|
||||
\item Assumes linearity and normality; real-world noodle expansion may exhibit nonlinear dynamics.
|
||||
\item Hyperparameters (e.g., \(\mathcal{N}(0, 10)\)) chosen for demonstration, not domain knowledge.
|
||||
\end{itemize}
|
||||
@@ -0,0 +1,37 @@
|
||||
// Template for a two-condition binary-choice model: fill in the
// parameters block, the priors, and the per-trial likelihood terms.
data {
  int<lower=1> N;  // number of trials
  array[N] real<lower=0> y;  // positive per-trial value — presumably response times (the companion CSV column is `rt`); confirm
  array[N] int<lower=1, upper=2> condition;  // experimental condition of each trial (1 or 2)
  array[N] int<lower=0, upper=1> choice;  // binary choice on each trial (0 or 1)
}

parameters {
  // Your code here
}

model {
  // Priors
  // Your code here

  // Likelihood: separate terms per condition and per choice outcome
  for (n in 1:N) {
    // Condition 1
    if (condition[n] == 1) {
      if (choice[n] == 1) {
        // Your code here
      }
      else {
        // Your code here
      }
    }
    // Condition 2
    if (condition[n] == 2) {
      if (choice[n] == 1) {
        // Your code here
      }
      else {
        // Your code here
      }
    }
  }
}
|
||||
|
After Width: | Height: | Size: 21 KiB |
@@ -0,0 +1,301 @@
|
||||
rt;choice;condition
|
||||
0.477;1.0;1.0
|
||||
0.6;1.0;1.0
|
||||
0.5;0.0;1.0
|
||||
0.416;1.0;1.0
|
||||
0.435;1.0;1.0
|
||||
0.499;1.0;1.0
|
||||
0.531;1.0;1.0
|
||||
0.616;1.0;1.0
|
||||
0.492;1.0;1.0
|
||||
0.682;1.0;1.0
|
||||
0.525;1.0;1.0
|
||||
0.714;1.0;1.0
|
||||
0.467;0.0;1.0
|
||||
1.106;1.0;1.0
|
||||
0.427;1.0;1.0
|
||||
0.681;1.0;1.0
|
||||
0.438;1.0;1.0
|
||||
0.584;0.0;1.0
|
||||
0.461;1.0;1.0
|
||||
0.466;1.0;1.0
|
||||
0.488;1.0;1.0
|
||||
0.431;1.0;1.0
|
||||
0.501;1.0;1.0
|
||||
0.444;1.0;1.0
|
||||
0.496;1.0;1.0
|
||||
0.5;1.0;1.0
|
||||
0.716;1.0;1.0
|
||||
0.449;1.0;1.0
|
||||
0.45;1.0;1.0
|
||||
0.552;1.0;1.0
|
||||
0.479;1.0;1.0
|
||||
0.497;1.0;1.0
|
||||
0.463;1.0;1.0
|
||||
0.54;0.0;1.0
|
||||
0.44;1.0;1.0
|
||||
0.425;1.0;1.0
|
||||
0.554;1.0;1.0
|
||||
0.663;1.0;1.0
|
||||
0.434;1.0;1.0
|
||||
0.463;1.0;1.0
|
||||
0.423;1.0;1.0
|
||||
0.423;1.0;1.0
|
||||
0.45;1.0;1.0
|
||||
0.687;1.0;1.0
|
||||
0.587;1.0;1.0
|
||||
0.584;1.0;1.0
|
||||
0.531;1.0;1.0
|
||||
0.718;1.0;1.0
|
||||
0.534;1.0;1.0
|
||||
0.565;1.0;1.0
|
||||
0.43;1.0;1.0
|
||||
0.505;0.0;1.0
|
||||
0.456;1.0;1.0
|
||||
0.668;1.0;1.0
|
||||
0.459;1.0;1.0
|
||||
0.509;1.0;1.0
|
||||
0.506;1.0;1.0
|
||||
0.741;1.0;1.0
|
||||
0.633;1.0;1.0
|
||||
0.475;1.0;1.0
|
||||
0.635;1.0;1.0
|
||||
0.456;1.0;1.0
|
||||
0.466;1.0;1.0
|
||||
0.567;1.0;1.0
|
||||
0.449;1.0;1.0
|
||||
0.451;1.0;1.0
|
||||
0.464;1.0;1.0
|
||||
0.467;1.0;1.0
|
||||
0.559;1.0;1.0
|
||||
0.425;1.0;1.0
|
||||
0.452;1.0;1.0
|
||||
0.411;1.0;1.0
|
||||
0.528;1.0;1.0
|
||||
0.429;1.0;1.0
|
||||
0.521;1.0;1.0
|
||||
0.54;0.0;1.0
|
||||
0.652;1.0;1.0
|
||||
0.687;1.0;1.0
|
||||
0.57;1.0;1.0
|
||||
0.484;0.0;1.0
|
||||
0.545;1.0;1.0
|
||||
0.479;1.0;1.0
|
||||
0.68;1.0;1.0
|
||||
0.434;1.0;1.0
|
||||
0.458;1.0;1.0
|
||||
0.501;1.0;1.0
|
||||
0.509;1.0;1.0
|
||||
0.462;1.0;1.0
|
||||
0.452;1.0;1.0
|
||||
0.522;1.0;1.0
|
||||
0.431;1.0;1.0
|
||||
0.43;1.0;1.0
|
||||
0.49;1.0;1.0
|
||||
0.697;1.0;1.0
|
||||
0.633;1.0;1.0
|
||||
0.539;1.0;1.0
|
||||
0.483;1.0;1.0
|
||||
1.11;1.0;1.0
|
||||
0.472;1.0;1.0
|
||||
0.757;1.0;1.0
|
||||
0.854;1.0;1.0
|
||||
0.653;1.0;1.0
|
||||
0.45;1.0;1.0
|
||||
0.516;1.0;1.0
|
||||
0.547;0.0;1.0
|
||||
0.432;1.0;1.0
|
||||
0.483;1.0;1.0
|
||||
0.501;1.0;1.0
|
||||
0.444;1.0;1.0
|
||||
0.515;1.0;1.0
|
||||
0.534;1.0;1.0
|
||||
0.441;1.0;1.0
|
||||
0.474;1.0;1.0
|
||||
0.513;1.0;1.0
|
||||
0.589;0.0;1.0
|
||||
0.446;1.0;1.0
|
||||
0.642;0.0;1.0
|
||||
0.591;1.0;1.0
|
||||
0.64;1.0;1.0
|
||||
0.449;1.0;1.0
|
||||
0.418;1.0;1.0
|
||||
0.615;1.0;1.0
|
||||
0.585;1.0;1.0
|
||||
0.459;1.0;1.0
|
||||
0.479;1.0;1.0
|
||||
0.477;1.0;1.0
|
||||
0.559;1.0;1.0
|
||||
0.419;1.0;1.0
|
||||
0.522;1.0;1.0
|
||||
0.429;1.0;1.0
|
||||
0.528;1.0;1.0
|
||||
0.467;1.0;1.0
|
||||
0.58;0.0;1.0
|
||||
0.487;1.0;1.0
|
||||
0.451;1.0;1.0
|
||||
0.527;1.0;1.0
|
||||
0.451;1.0;1.0
|
||||
0.49;1.0;1.0
|
||||
0.514;1.0;1.0
|
||||
0.455;1.0;1.0
|
||||
0.507;1.0;1.0
|
||||
0.474;1.0;1.0
|
||||
0.458;1.0;1.0
|
||||
0.454;1.0;1.0
|
||||
0.518;1.0;1.0
|
||||
0.429;1.0;1.0
|
||||
0.96;1.0;1.0
|
||||
0.427;1.0;1.0
|
||||
0.802;1.0;1.0
|
||||
0.446;1.0;1.0
|
||||
0.439;0.0;2.0
|
||||
0.471;0.0;2.0
|
||||
0.917;0.0;2.0
|
||||
0.562;1.0;2.0
|
||||
0.678;0.0;2.0
|
||||
0.671;1.0;2.0
|
||||
0.599;0.0;2.0
|
||||
0.638;0.0;2.0
|
||||
0.494;0.0;2.0
|
||||
0.498;1.0;2.0
|
||||
0.582;0.0;2.0
|
||||
0.672;1.0;2.0
|
||||
0.449;1.0;2.0
|
||||
0.585;0.0;2.0
|
||||
0.514;1.0;2.0
|
||||
0.493;1.0;2.0
|
||||
0.437;0.0;2.0
|
||||
0.452;1.0;2.0
|
||||
0.727;0.0;2.0
|
||||
0.523;1.0;2.0
|
||||
0.485;1.0;2.0
|
||||
0.439;1.0;2.0
|
||||
0.683;0.0;2.0
|
||||
0.578;1.0;2.0
|
||||
0.431;1.0;2.0
|
||||
0.562;0.0;2.0
|
||||
0.471;1.0;2.0
|
||||
0.786;1.0;2.0
|
||||
0.434;1.0;2.0
|
||||
0.441;1.0;2.0
|
||||
0.745;1.0;2.0
|
||||
0.533;1.0;2.0
|
||||
0.756;0.0;2.0
|
||||
0.678;1.0;2.0
|
||||
0.494;1.0;2.0
|
||||
1.028;1.0;2.0
|
||||
0.475;0.0;2.0
|
||||
0.563;0.0;2.0
|
||||
0.483;1.0;2.0
|
||||
0.566;0.0;2.0
|
||||
0.466;1.0;2.0
|
||||
1.086;1.0;2.0
|
||||
0.573;1.0;2.0
|
||||
0.597;1.0;2.0
|
||||
0.597;0.0;2.0
|
||||
0.446;1.0;2.0
|
||||
0.437;1.0;2.0
|
||||
0.515;1.0;2.0
|
||||
0.524;0.0;2.0
|
||||
0.513;1.0;2.0
|
||||
0.465;1.0;2.0
|
||||
0.704;1.0;2.0
|
||||
0.801;1.0;2.0
|
||||
0.484;0.0;2.0
|
||||
0.459;0.0;2.0
|
||||
0.576;0.0;2.0
|
||||
0.462;1.0;2.0
|
||||
0.471;0.0;2.0
|
||||
0.595;1.0;2.0
|
||||
0.464;1.0;2.0
|
||||
0.644;1.0;2.0
|
||||
0.42;0.0;2.0
|
||||
0.452;1.0;2.0
|
||||
0.488;0.0;2.0
|
||||
0.568;1.0;2.0
|
||||
0.481;0.0;2.0
|
||||
0.5;1.0;2.0
|
||||
0.54;1.0;2.0
|
||||
0.447;0.0;2.0
|
||||
0.463;1.0;2.0
|
||||
0.507;1.0;2.0
|
||||
0.522;1.0;2.0
|
||||
0.58;1.0;2.0
|
||||
0.464;0.0;2.0
|
||||
0.507;0.0;2.0
|
||||
0.727;1.0;2.0
|
||||
0.452;1.0;2.0
|
||||
0.636;0.0;2.0
|
||||
0.552;1.0;2.0
|
||||
0.739;1.0;2.0
|
||||
0.468;1.0;2.0
|
||||
0.563;1.0;2.0
|
||||
0.443;1.0;2.0
|
||||
1.023;1.0;2.0
|
||||
0.571;1.0;2.0
|
||||
0.44;0.0;2.0
|
||||
0.717;1.0;2.0
|
||||
0.751;1.0;2.0
|
||||
0.491;0.0;2.0
|
||||
0.456;1.0;2.0
|
||||
0.569;1.0;2.0
|
||||
0.456;1.0;2.0
|
||||
0.517;1.0;2.0
|
||||
0.492;1.0;2.0
|
||||
0.527;1.0;2.0
|
||||
0.501;0.0;2.0
|
||||
0.499;0.0;2.0
|
||||
0.428;0.0;2.0
|
||||
0.529;0.0;2.0
|
||||
0.43;0.0;2.0
|
||||
0.453;0.0;2.0
|
||||
0.484;0.0;2.0
|
||||
0.541;1.0;2.0
|
||||
0.707;0.0;2.0
|
||||
0.712;1.0;2.0
|
||||
0.53;0.0;2.0
|
||||
0.871;1.0;2.0
|
||||
0.896;1.0;2.0
|
||||
0.548;0.0;2.0
|
||||
0.484;1.0;2.0
|
||||
0.779;1.0;2.0
|
||||
0.503;0.0;2.0
|
||||
0.696;0.0;2.0
|
||||
0.522;0.0;2.0
|
||||
0.93;1.0;2.0
|
||||
0.535;0.0;2.0
|
||||
0.615;1.0;2.0
|
||||
0.624;1.0;2.0
|
||||
0.742;0.0;2.0
|
||||
0.528;1.0;2.0
|
||||
0.441;1.0;2.0
|
||||
0.514;0.0;2.0
|
||||
0.445;0.0;2.0
|
||||
0.625;0.0;2.0
|
||||
0.578;1.0;2.0
|
||||
0.55;1.0;2.0
|
||||
0.686;1.0;2.0
|
||||
0.505;1.0;2.0
|
||||
0.872;1.0;2.0
|
||||
0.548;1.0;2.0
|
||||
0.487;0.0;2.0
|
||||
0.733;0.0;2.0
|
||||
0.46;0.0;2.0
|
||||
0.764;1.0;2.0
|
||||
0.589;0.0;2.0
|
||||
0.482;0.0;2.0
|
||||
0.449;0.0;2.0
|
||||
0.428;0.0;2.0
|
||||
0.604;1.0;2.0
|
||||
0.505;1.0;2.0
|
||||
0.649;1.0;2.0
|
||||
0.484;1.0;2.0
|
||||
0.535;0.0;2.0
|
||||
0.471;0.0;2.0
|
||||
0.441;0.0;2.0
|
||||
0.528;0.0;2.0
|
||||
0.621;0.0;2.0
|
||||
0.48;1.0;2.0
|
||||
0.693;1.0;2.0
|
||||
0.493;1.0;2.0
|
||||
|
|
After Width: | Height: | Size: 378 KiB |
|
After Width: | Height: | Size: 50 KiB |
|
After Width: | Height: | Size: 48 KiB |
|
After Width: | Height: | Size: 55 KiB |
@@ -0,0 +1,215 @@
|
||||
alabaster==1.0.0
|
||||
annotated-types==0.7.0
|
||||
anyascii==0.3.2
|
||||
anyio==4.8.0
|
||||
apparmor==4.0.3
|
||||
application-utility==1.4.0
|
||||
arandr==0.1.11
|
||||
argcomplete==3.5.3
|
||||
arrow==1.3.0
|
||||
attrs==23.2.1.dev0
|
||||
autocommand==2.2.2
|
||||
Automat==22.10.0
|
||||
b2==4.3.0
|
||||
b2sdk==2.7.0
|
||||
Babel==2.15.0
|
||||
black==25.1.0
|
||||
breezy==3.3.9
|
||||
btrfsutil==6.13
|
||||
build==1.2.2
|
||||
CacheControl==0.14.2
|
||||
cachy==0.3.0
|
||||
certifi==2025.1.31
|
||||
cffi==1.17.1
|
||||
cfgv==3.4.0
|
||||
chardet==5.2.0
|
||||
charset-normalizer==3.4.1
|
||||
cleo==2.1.0
|
||||
click==8.1.8
|
||||
configobj==5.0.9
|
||||
constantly==23.10.4
|
||||
crashtest==0.4.1
|
||||
cryptography==44.0.2
|
||||
cssselect2==0.7.0
|
||||
Cython==0.29.37
|
||||
dbus-python==1.4.0
|
||||
deluge==2.1.1
|
||||
distlib==0.3.9
|
||||
distro==1.9.0
|
||||
docopt==0.6.2
|
||||
docutils==0.21.2
|
||||
dulwich==0.22.8
|
||||
ecdsa==0.19.0
|
||||
editables==0.5
|
||||
fastbencode==0.3.1
|
||||
fastjsonschema==2.21.1
|
||||
filelock==3.18.0
|
||||
findpython==0.6.3
|
||||
flit_core==3.11.0
|
||||
GeoIP==1.3.2
|
||||
gpg==1.24.2
|
||||
gufw==24.4.0
|
||||
h11==0.14.0
|
||||
hatch==1.14.0
|
||||
hatchling==1.27.0
|
||||
html5lib==1.1
|
||||
httpcore==1.0.7
|
||||
httplib2==0.22.0
|
||||
httpx==0.28.1
|
||||
hyperlink==21.0.0
|
||||
identify==2.6.9
|
||||
idna==3.10
|
||||
ifaddr==0.2.0
|
||||
imagesize==1.4.1
|
||||
importlib_metadata==7.2.1
|
||||
incremental==22.10.0
|
||||
inflect==7.5.0
|
||||
iniconfig==2.0.0
|
||||
installer==0.7.0
|
||||
jaraco.classes==3.4.0
|
||||
jaraco.collections==5.1.0
|
||||
jaraco.context==6.0.1
|
||||
jaraco.functools==4.1.0
|
||||
jaraco.text==4.0.0
|
||||
jeepney==0.9.0
|
||||
Jinja2==3.1.5
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
keyring==25.6.0
|
||||
lark==1.2.2
|
||||
launchpadlib==2.0.0
|
||||
lazr.restfulclient==0.14.6
|
||||
lazr.uri==1.0.6
|
||||
legacy-cgi==2.6.2
|
||||
lensfun==0.3.4
|
||||
LibAppArmor==4.0.3
|
||||
libtorrent==2.0.11
|
||||
lit==19.1.7.dev0
|
||||
lockfile==0.12.2
|
||||
logfury==1.0.1
|
||||
lxml==5.3.1
|
||||
Mako==1.3.9.dev0
|
||||
Markdown==3.7
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==2.1.5
|
||||
material-color-utilities-python==0.1.5
|
||||
mdurl==0.1.2
|
||||
merge3==0.0.15
|
||||
meson==1.7.0
|
||||
more-itertools==10.6.0
|
||||
msgpack==1.0.5
|
||||
mypy_extensions==1.0.0
|
||||
netifaces==0.11.0
|
||||
nftables==0.1
|
||||
nodeenv==1.9.1
|
||||
npyscreen==4.10.5
|
||||
oauthlib==3.2.2
|
||||
ordered-set==4.1.0
|
||||
packaging==24.2
|
||||
pacman_mirrors==4.27
|
||||
pathspec==0.12.1
|
||||
patiencediff==0.2.15
|
||||
pbs-installer==2025.3.17
|
||||
pefile==2024.8.26
|
||||
pexpect==4.9.0
|
||||
phx-class-registry==4.0.6
|
||||
pillow==11.1.0
|
||||
pkginfo==1.12.0
|
||||
platformdirs==4.3.6
|
||||
pluggy==1.5.0
|
||||
poetry==2.1.1
|
||||
poetry-core==2.1.1
|
||||
poetry-plugin-export==1.9.0
|
||||
pre_commit==4.1.0
|
||||
psutil==7.0.0
|
||||
ptyprocess==0.7.0
|
||||
pyasn1==0.6.0
|
||||
pyasn1_modules==0.4.0
|
||||
pycairo==1.27.0
|
||||
pycparser==2.22
|
||||
pydantic==2.10.6
|
||||
pydantic_core==2.27.2
|
||||
pygame_sdl2==2.1.0
|
||||
Pygments==2.19.1
|
||||
PyGObject==3.52.3
|
||||
pynotify==1.3.0
|
||||
pyOpenSSL==25.0.0
|
||||
pyparsing==3.2.1
|
||||
pyproject_hooks==1.2.0
|
||||
PyQt5==5.15.11
|
||||
PyQt5_sip==12.17.0
|
||||
PyQt6==6.8.1
|
||||
PyQt6_sip==13.10.0
|
||||
PySocks==1.7.1
|
||||
pytest==8.3.5
|
||||
python-dateutil==2.9.0
|
||||
pytz==2025.1
|
||||
pyxdg==0.28
|
||||
PyYAML==6.0.2
|
||||
ranger-fm==1.9.4
|
||||
RapidFuzz==3.12.2
|
||||
referencing==0.35.1
|
||||
regex==2024.11.6
|
||||
rencode==1.0.6
|
||||
reportlab==4.2.2
|
||||
requests==2.32.3
|
||||
requests-toolbelt==1.0.0
|
||||
rich==13.9.4
|
||||
roman-numerals-py==3.1.0
|
||||
rpds-py==0.22.3
|
||||
rsa==4.9
|
||||
rst2ansi==0.1.5
|
||||
SecretStorage==3.3.3
|
||||
service-identity==24.2.0
|
||||
setproctitle==1.3.5
|
||||
setuptools==75.8.0
|
||||
setuptools-scm==8.2.1
|
||||
shellingham==1.5.4
|
||||
simplejson==3.20.1
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
snowballstemmer==2.2.0
|
||||
speedtest-cli==2.1.3
|
||||
Sphinx==8.2.3
|
||||
sphinx_rtd_dark_mode==1.3.0
|
||||
sphinx_rtd_theme==2.0.0
|
||||
sphinxcontrib-applehelp==2.0.0
|
||||
sphinxcontrib-devhelp==2.0.0
|
||||
sphinxcontrib-htmlhelp==2.1.0
|
||||
sphinxcontrib-jquery==4.1
|
||||
sphinxcontrib-jsmath==1.0.1
|
||||
sphinxcontrib-qthelp==2.0.0
|
||||
sphinxcontrib-serializinghtml==2.0.0
|
||||
svglib==1.5.1
|
||||
systemd-python==235
|
||||
tabulate==0.9.0
|
||||
termcolor==2.5.0
|
||||
tinycss2==1.4.0
|
||||
tomli==2.0.1
|
||||
tomli_w==1.2.0
|
||||
tomlkit==0.13.2
|
||||
torbrowser-launcher==0.3.7
|
||||
tqdm==4.67.1
|
||||
trove-classifiers==2025.3.19.19
|
||||
Twisted==24.3.0
|
||||
typeguard==4.4.2
|
||||
types-python-dateutil==2.9.0.20241206
|
||||
typing_extensions==4.12.2
|
||||
tzlocal==5.3.1
|
||||
ueberzug==18.3.1
|
||||
ufw==0.36.2
|
||||
urllib3==2.3.0
|
||||
userpath==1.9.2
|
||||
uv==0.6.10
|
||||
validate==5.0.9
|
||||
validate-pyproject==0.24.1
|
||||
virtualenv==20.28.0
|
||||
wadllib==2.0.0
|
||||
webencodings==0.5.1
|
||||
wheel==0.45.1
|
||||
woeusb-ng==0.2.12
|
||||
wxPython==4.2.2
|
||||
Yapsy==2.0.0
|
||||
zipp==3.21.0
|
||||
zope.interface==7.2
|
||||
zstandard==0.23.0
|
||||
@@ -0,0 +1,12 @@
|
||||
all: run analyze doc
|
||||
|
||||
run:
|
||||
python main.py
|
||||
|
||||
analyze:
|
||||
python main.py --analyze
|
||||
|
||||
doc:
|
||||
pdflatex report.tex
|
||||
pdflatex report.tex
|
||||
pdflatex report.tex
|
||||
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.974456560897827,0.26922,1.7304970249176026,0.3786
|
||||
2,1.8454077367782593,0.31724,1.6601378639221192,0.385
|
||||
3,1.8263939748764038,0.32684,1.6606147064208985,0.3904
|
||||
4,1.8094575137710571,0.33182,1.665997783279419,0.3756
|
||||
5,1.7877959413909912,0.3392,1.58222125415802,0.4152
|
||||
6,1.787890262145996,0.34048,1.7176293952941895,0.3643
|
||||
7,1.7759525202178954,0.34648,1.6122663572311402,0.4023
|
||||
8,1.7836139296722413,0.3439,1.7781054580688476,0.341
|
||||
9,1.7774498300933839,0.34556,1.5754944620132447,0.4182
|
||||
10,1.7656044020462036,0.34748,1.5426378156661986,0.4295
|
||||
11,1.7746756098556518,0.34388,1.5636182432174683,0.4233
|
||||
12,1.7580935064315797,0.354,1.5672656684875488,0.4113
|
||||
13,1.7512612063217163,0.35642,1.6501448486328125,0.3891
|
||||
14,1.7724904922485352,0.3475,1.5432163452148437,0.4352
|
||||
15,1.7634696194076538,0.35212,1.5920137149810791,0.4202
|
||||
16,1.7557356859970092,0.35386,1.5618241235733032,0.4283
|
||||
17,1.7429854767227173,0.3581,1.5591908136367798,0.4274
|
||||
18,1.7365073428726197,0.3617,1.5122350412368775,0.4454
|
||||
19,1.7340069851303102,0.3602,1.501157410812378,0.451
|
||||
20,1.7409456158065797,0.35886,1.531626453781128,0.4443
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.1344652434539797,0.20488,2.022770263671875,0.2597
|
||||
2,2.0186911404418946,0.24922,1.8970225284576416,0.2983
|
||||
3,2.010243484954834,0.25274,1.9072177560806274,0.29
|
||||
4,1.9952284572601318,0.26088,1.8907000495910644,0.3027
|
||||
5,1.9962889101791381,0.26108,1.8968271499633789,0.3007
|
||||
6,1.9906421136856078,0.25978,1.8710549690246583,0.2991
|
||||
7,1.9789404233551025,0.26648,1.8729571523666382,0.3107
|
||||
8,1.97843787651062,0.26624,1.8582117340087891,0.3136
|
||||
9,1.9777873904037475,0.26486,1.862920560836792,0.3152
|
||||
10,1.9690655079650878,0.26784,1.8784164529800416,0.3076
|
||||
11,1.9658676456069946,0.27026,1.8229821308135987,0.3258
|
||||
12,1.9590900508117677,0.26692,1.8474713500976563,0.3287
|
||||
13,1.9669304356384278,0.26908,1.8680181037902832,0.3111
|
||||
14,1.958348878631592,0.27156,1.8127110235214234,0.3475
|
||||
15,1.9552243856430054,0.27324,1.8901950101852416,0.2852
|
||||
16,1.956668193511963,0.27416,1.8504865715026855,0.3122
|
||||
17,1.948801014099121,0.27538,1.7920600860595703,0.3394
|
||||
18,1.9498589097595216,0.27362,1.8019978916168213,0.3464
|
||||
19,1.9449512919616698,0.27842,1.8016673805236816,0.3433
|
||||
20,1.9463943926620484,0.27408,1.8101041078567506,0.3376
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.847881386680603,0.32942,1.5511522064208985,0.4312
|
||||
2,1.6177299839401245,0.41082,1.4712652378082276,0.4657
|
||||
3,1.5569700075531006,0.43428,1.394599464225769,0.4915
|
||||
4,1.5289771292495729,0.4452,1.365533829307556,0.5026
|
||||
5,1.5117075008392333,0.45238,1.3679044742584228,0.5022
|
||||
6,1.4944573168182373,0.45804,1.3945600986480713,0.4908
|
||||
7,1.4954497344589233,0.46082,1.3769316219329835,0.5027
|
||||
8,1.4939354167556762,0.4581,1.4132782371520995,0.5003
|
||||
9,1.4871603671264648,0.46338,1.3577482885360719,0.4995
|
||||
10,1.4732131560516357,0.46792,1.4151488857269288,0.4874
|
||||
11,1.4649469945526123,0.468,1.3565514822006226,0.514
|
||||
12,1.4616536054229736,0.47348,1.3092844657897948,0.5245
|
||||
13,1.4838773123168945,0.46748,1.3107160717010498,0.5241
|
||||
14,1.470268120613098,0.47032,1.3532153959274291,0.514
|
||||
15,1.4701028344345093,0.47292,1.3349319314956665,0.5207
|
||||
16,1.4733314770889283,0.46694,1.30636460647583,0.5253
|
||||
17,1.4659587969207764,0.47168,1.3394561473846436,0.5188
|
||||
18,1.465981600227356,0.4697,1.3647231616973876,0.5147
|
||||
19,1.4668639030456543,0.46984,1.3287893884658812,0.5175
|
||||
20,1.4546778205108644,0.47452,1.3466038597106933,0.5132
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9015797283554077,0.3039,1.6871370712280274,0.3947
|
||||
2,1.6032403519821168,0.4152,1.4519082370758056,0.4649
|
||||
3,1.5410432431793213,0.44184,1.371306716156006,0.4915
|
||||
4,1.5103425858688355,0.45614,1.4346882278442383,0.4913
|
||||
5,1.5013815983963013,0.456,1.3663070734024048,0.4951
|
||||
6,1.5020151425552368,0.45416,1.36682343044281,0.5039
|
||||
7,1.49663742893219,0.45978,1.323156630897522,0.5232
|
||||
8,1.477861526107788,0.46658,1.3874903659820557,0.5081
|
||||
9,1.472189111251831,0.46874,1.3448004537582396,0.5121
|
||||
10,1.4770725217819214,0.46474,1.341456114578247,0.5166
|
||||
11,1.4564127297210694,0.47708,1.2886124462127686,0.5293
|
||||
12,1.5129934811019898,0.45428,1.3219362440109252,0.5281
|
||||
13,1.4658231272125244,0.4736,1.3313484254837036,0.522
|
||||
14,1.4768860134887696,0.46834,1.333654520225525,0.5173
|
||||
15,1.4547452118301392,0.4728,1.3101324228286744,0.5272
|
||||
16,1.4501867012405396,0.47806,1.3292738021850585,0.5192
|
||||
17,1.461766162033081,0.47254,1.3780369548797606,0.5053
|
||||
18,1.447611441307068,0.47654,1.277954197692871,0.547
|
||||
19,1.4491669342803954,0.47664,1.2802382738113403,0.5409
|
||||
20,1.4606599185180664,0.4743,1.3675791484832764,0.4984
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8454712069320678,0.32346,1.5677360635757447,0.4279
|
||||
2,1.6003490134811402,0.41786,1.4861061136245728,0.4632
|
||||
3,1.5371953969573975,0.44236,1.4539190216064453,0.4733
|
||||
4,1.524127781944275,0.45006,1.3829264806747437,0.505
|
||||
5,1.5006125688171388,0.45966,1.3783319478988647,0.502
|
||||
6,1.4872980946731567,0.46374,1.3886339023590089,0.5019
|
||||
7,1.4861942530059815,0.46,1.4412277736663819,0.4721
|
||||
8,1.5050456002044679,0.458,1.3324725828170776,0.5218
|
||||
9,1.4739490990066528,0.46516,1.323778917312622,0.5296
|
||||
10,1.4616954332733154,0.47098,1.3400814496994018,0.5201
|
||||
11,1.4587705549240113,0.47354,1.309501482772827,0.5378
|
||||
12,1.4562981714630128,0.47654,1.3954745462417601,0.5143
|
||||
13,1.450860862197876,0.4802,1.3416914821624757,0.5187
|
||||
14,1.4455588480377197,0.48014,1.2786856636047363,0.548
|
||||
15,1.4538489317321777,0.48116,1.315366445350647,0.5394
|
||||
16,1.4860243936157227,0.46902,1.3105680742263794,0.5308
|
||||
17,1.4425668993759155,0.4783,1.3113457733154297,0.5372
|
||||
18,1.4487706513214111,0.4759,1.3786734741210938,0.5051
|
||||
19,1.4291919551467895,0.48284,1.3119900798797608,0.5372
|
||||
20,1.4437602603912354,0.48114,1.259741804122925,0.556
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0573630228042603,0.2491,1.8283799295425416,0.3183
|
||||
2,1.808153849182129,0.3266,1.6456665771484376,0.378
|
||||
3,1.753247998008728,0.34604,1.6158345397949219,0.3985
|
||||
4,1.721941413230896,0.36686,1.5776695816040038,0.4228
|
||||
5,1.7048808489990235,0.37028,1.5212398496627808,0.4512
|
||||
6,1.6865429331207276,0.37936,1.481686958694458,0.449
|
||||
7,1.6720955381011964,0.3858,1.5646504203796388,0.4257
|
||||
8,1.6756063533782959,0.38306,1.5534727085113524,0.4274
|
||||
9,1.6735252333068849,0.3839,1.4908922216415406,0.4686
|
||||
10,1.6613293509292602,0.3906,1.5141861679077149,0.451
|
||||
11,1.660574383468628,0.38788,1.5027179416656493,0.4561
|
||||
12,1.6742186321640016,0.38374,1.4785913095474243,0.4539
|
||||
13,1.6540580016326905,0.3938,1.4502152013778686,0.4571
|
||||
14,1.6618211807632446,0.39018,1.5403156909942628,0.4337
|
||||
15,1.6527680168533325,0.39332,1.4511047771453858,0.4789
|
||||
16,1.645100059890747,0.39646,1.43847534198761,0.4773
|
||||
17,1.6513261923599243,0.39488,1.4823436264038086,0.4453
|
||||
18,1.6383419089508056,0.40116,1.495237015724182,0.4499
|
||||
19,1.6509938947296143,0.39756,1.4486223876953126,0.4687
|
||||
20,1.6417471273422242,0.39988,1.4995165328979492,0.462
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.014116881027222,0.26518,1.7019136341094971,0.3769
|
||||
2,1.8181559704208374,0.32602,1.6809476217269896,0.3798
|
||||
3,1.7674869966888427,0.34662,1.6972094938278197,0.373
|
||||
4,1.752722897453308,0.3521,1.5924717903137207,0.4044
|
||||
5,1.748039944114685,0.35244,1.5717085536956787,0.4246
|
||||
6,1.7371159115219117,0.35314,1.6360073663711547,0.3957
|
||||
7,1.7332507349395752,0.35908,1.5807278701782226,0.4193
|
||||
8,1.7280795514678955,0.3615,1.5841178344726563,0.4158
|
||||
9,1.7286223630142212,0.35826,1.5960125846862794,0.4188
|
||||
10,1.7247096759414673,0.36294,1.580347721672058,0.4255
|
||||
11,1.72447563621521,0.36342,1.6092770568847656,0.406
|
||||
12,1.7200987982559204,0.36386,1.58438682346344,0.4252
|
||||
13,1.7571224234008789,0.35432,1.7571450429916382,0.3504
|
||||
14,1.7474284732818604,0.3558,1.620937774848938,0.3959
|
||||
15,1.7277102383804321,0.3594,1.5965348707199096,0.4276
|
||||
16,1.7246993078613282,0.36316,1.5578019842147828,0.429
|
||||
17,1.716722907333374,0.3643,1.5713054029464721,0.424
|
||||
18,1.7226240491104126,0.35982,1.5993318134307861,0.3964
|
||||
19,1.7280442880249023,0.36182,1.5774778314590454,0.4154
|
||||
20,1.7140190068817138,0.36316,1.5883642023086548,0.4128
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.074842978553772,0.2321,1.8135286460876465,0.3464
|
||||
2,1.8202483989715577,0.33832,1.601849068069458,0.4114
|
||||
3,1.6800080819702148,0.38962,1.4346690870285035,0.4786
|
||||
4,1.6108920947265626,0.41422,1.410789061355591,0.4937
|
||||
5,1.5602197430038451,0.43434,1.3864108823776244,0.4937
|
||||
6,1.5249909924697875,0.44654,1.3081797481536865,0.5289
|
||||
7,1.4851352381134033,0.4626,1.3037838243484496,0.5382
|
||||
8,1.4561926398849487,0.47614,1.2144829073905945,0.5633
|
||||
9,1.4245338591766357,0.48662,1.1912867235183715,0.5756
|
||||
10,1.404860743598938,0.49476,1.1602596988677978,0.5834
|
||||
11,1.3792334610748291,0.50492,1.1169185941696167,0.6075
|
||||
12,1.3638545850372314,0.51038,1.083746408367157,0.6156
|
||||
13,1.339594167098999,0.52216,1.0875441018104552,0.6145
|
||||
14,1.3327489657974243,0.52494,1.0794508327484131,0.6234
|
||||
15,1.314370972480774,0.5307,1.0695297025680541,0.6197
|
||||
16,1.2982412660217286,0.53638,1.0055400819778442,0.6471
|
||||
17,1.2795147993087768,0.54412,1.0264243045806885,0.6391
|
||||
18,1.2759248006820678,0.54682,1.0114541644096375,0.6519
|
||||
19,1.2656549953079224,0.5506,0.9857818864822387,0.6522
|
||||
20,1.246139160194397,0.556,0.9813527516365051,0.6549
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.084733605880737,0.23054,1.805083052444458,0.3441
|
||||
2,1.8338969733428956,0.33488,1.5635989295959474,0.443
|
||||
3,1.6894913947296142,0.3863,1.4450351499557494,0.4813
|
||||
4,1.6148385356521606,0.41256,1.4190383405685425,0.4857
|
||||
5,1.5608203945922852,0.4371,1.3209244800567628,0.5285
|
||||
6,1.5059256549453734,0.4561,1.268412045097351,0.5529
|
||||
7,1.4648617097091674,0.47256,1.2074650314331055,0.563
|
||||
8,1.434335672302246,0.4834,1.1721592065811157,0.5902
|
||||
9,1.40667203125,0.4967,1.1277678930282593,0.5973
|
||||
10,1.374456730003357,0.50802,1.145589846420288,0.5938
|
||||
11,1.3582844898986817,0.514,1.0862995162963867,0.6214
|
||||
12,1.3364373119354247,0.5204,1.0743581775665283,0.6195
|
||||
13,1.3238696041107179,0.52766,1.0620485213279725,0.6267
|
||||
14,1.3081258194732666,0.53312,1.0225723578453063,0.6412
|
||||
15,1.288412784729004,0.53812,1.0684286714553832,0.6204
|
||||
16,1.2836802099990845,0.5423,1.0006577738761901,0.6496
|
||||
17,1.2708675945281982,0.54806,1.0162318713188172,0.6482
|
||||
18,1.2681347032546997,0.54748,1.0420465433120727,0.6345
|
||||
19,1.2533282422256469,0.55522,0.963789785194397,0.6623
|
||||
20,1.2364999733352662,0.56218,0.9658513239860534,0.6628
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0803882836151124,0.2287,1.789647598838806,0.3638
|
||||
2,1.8214346895599365,0.33986,1.5701804725646973,0.428
|
||||
3,1.6871508443450929,0.38696,1.5095456941604615,0.4568
|
||||
4,1.6141029256820678,0.41464,1.4065279888153077,0.5012
|
||||
5,1.5632627973175048,0.43602,1.3370622457504273,0.5173
|
||||
6,1.5216406153869628,0.45132,1.2780111961364746,0.5414
|
||||
7,1.4851566635894775,0.46554,1.254789323425293,0.5411
|
||||
8,1.4516518106079102,0.477,1.224256973171234,0.5632
|
||||
9,1.4208706899261474,0.48922,1.2024554361343385,0.5696
|
||||
10,1.3847549717330934,0.50358,1.1307642597198486,0.6014
|
||||
11,1.367606183242798,0.50876,1.1109876449584961,0.6076
|
||||
12,1.3521447608184813,0.51576,1.0782893854141236,0.611
|
||||
13,1.3337134902954102,0.52338,1.078674609375,0.6199
|
||||
14,1.3183710289382935,0.5275,1.0492525255203247,0.6326
|
||||
15,1.309831443862915,0.53334,1.0428013453483582,0.6284
|
||||
16,1.2975196259307862,0.5399,1.0418389265060424,0.6349
|
||||
17,1.2833650760650634,0.5448,1.1197743309020995,0.596
|
||||
18,1.2782311374664306,0.54604,0.9917507829666138,0.6546
|
||||
19,1.2591375436401366,0.55266,0.9647492932319641,0.6582
|
||||
20,1.2543545852279663,0.55702,0.9729726590156555,0.6574
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9567125562286376,0.2868,1.6595897144317626,0.4048
|
||||
2,1.5607059629058837,0.43864,1.4107039302825928,0.4848
|
||||
3,1.393210672225952,0.49774,1.2663319400787354,0.5489
|
||||
4,1.3098828198623658,0.52998,1.205306915473938,0.5761
|
||||
5,1.2406870504379273,0.5571,1.1519119177818298,0.5984
|
||||
6,1.1775763750839234,0.58036,1.1193829996109008,0.6051
|
||||
7,1.1138219023132325,0.60238,1.0738257904052735,0.6176
|
||||
8,1.070581941833496,0.61852,1.0231942397117615,0.6371
|
||||
9,1.030339531364441,0.63532,0.9935429817199707,0.6441
|
||||
10,0.9900466967391968,0.65002,0.9499927408218384,0.6665
|
||||
11,0.9637177305793763,0.66046,0.9340457020759583,0.6675
|
||||
12,0.9254747213745117,0.67558,0.9082785024642944,0.6792
|
||||
13,0.908202090473175,0.6779,0.9268069096565247,0.6748
|
||||
14,0.8821638398742676,0.68774,0.8835061459541321,0.6887
|
||||
15,0.8548155406951904,0.69706,0.8872944156646728,0.6871
|
||||
16,0.8442007188415528,0.70044,0.8849920805931091,0.6918
|
||||
17,0.8165878020095825,0.70872,0.8710031011581421,0.6989
|
||||
18,0.7962540114593506,0.718,0.8807716882705688,0.6914
|
||||
19,0.7728271597862244,0.72794,0.852915700340271,0.7069
|
||||
20,0.758258864402771,0.7303,0.8573900122642517,0.7003
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0236042365264892,0.257,1.723145390892029,0.3901
|
||||
2,1.5985155703353882,0.4249,1.4042083190917969,0.4924
|
||||
3,1.4218818955612182,0.48632,1.2928504508972167,0.5413
|
||||
4,1.3203161227035523,0.52612,1.2168406352996826,0.5661
|
||||
5,1.2403734153366088,0.55756,1.149209902381897,0.5939
|
||||
6,1.1671098303222656,0.58366,1.1203332425117494,0.5994
|
||||
7,1.112549086380005,0.6046,1.0411104937553406,0.6325
|
||||
8,1.0569730423927308,0.62496,1.014323724079132,0.6456
|
||||
9,1.0219695941925049,0.6371,0.9721902985572815,0.659
|
||||
10,0.9759654503250123,0.65288,0.942803035736084,0.6697
|
||||
11,0.9441324331092834,0.66574,0.9112948065757751,0.679
|
||||
12,0.9140988022041321,0.67698,0.8975791595458984,0.6881
|
||||
13,0.8857126325798035,0.68738,0.8933684177398682,0.689
|
||||
14,0.859081026725769,0.69532,0.8806665014266968,0.6958
|
||||
15,0.8373628576278687,0.7053,0.8862028203964233,0.6894
|
||||
16,0.8062883552742004,0.71626,0.8928611957550049,0.6907
|
||||
17,0.7948164237213134,0.71876,0.8475546907424927,0.7084
|
||||
18,0.7735403092384339,0.72452,0.8641609072685241,0.7011
|
||||
19,0.7580092010688781,0.7301,0.8455605938911438,0.7051
|
||||
20,0.7356696060180664,0.73766,0.8630998106002807,0.704
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9932699127578735,0.26908,1.652609211730957,0.4115
|
||||
2,1.6025463925170897,0.42024,1.4165557703018188,0.4886
|
||||
3,1.4385107410812379,0.47978,1.3373357730865478,0.5276
|
||||
4,1.329945457725525,0.52296,1.207246643447876,0.5639
|
||||
5,1.2507697713470458,0.55224,1.1396601531982422,0.5997
|
||||
6,1.1825159574127198,0.5785,1.0910413694381713,0.6131
|
||||
7,1.129948296546936,0.59706,1.0626532356262206,0.6223
|
||||
8,1.0810239276123046,0.61532,1.0606243711471557,0.6271
|
||||
9,1.0398018453979492,0.63036,0.9847188907623291,0.6508
|
||||
10,1.0001467765808105,0.64718,0.978429046344757,0.6578
|
||||
11,0.9662233278274536,0.65672,0.9435468803405762,0.6703
|
||||
12,0.9290380508804321,0.67348,0.9357032821655273,0.6739
|
||||
13,0.8968851718902587,0.6826,0.9126621349334717,0.6813
|
||||
14,0.8746093656349182,0.69122,0.889759654712677,0.6948
|
||||
15,0.8559883063125611,0.69704,0.8787766233444214,0.6967
|
||||
16,0.8247946104431152,0.70706,0.8809524188995361,0.6961
|
||||
17,0.8109111586380005,0.71618,0.8809684686660767,0.6933
|
||||
18,0.7986607175636291,0.7161,0.8621552461624146,0.7016
|
||||
19,0.7699470348548889,0.72814,0.8589315553665161,0.7061
|
||||
20,0.7543044216918945,0.73224,0.8606150555610657,0.7066
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0493293154907226,0.2425,1.7525886156082153,0.3748
|
||||
2,1.7320209578704835,0.3674,1.5615552814483642,0.4388
|
||||
3,1.598600202560425,0.4189,1.3815029901504516,0.4997
|
||||
4,1.5132850009155274,0.44868,1.3364418529510498,0.5207
|
||||
5,1.453274436149597,0.47022,1.2501229566574097,0.5498
|
||||
6,1.3984402522277832,0.49416,1.2111383409500123,0.5629
|
||||
7,1.3466207969284059,0.51412,1.1169193510055542,0.6038
|
||||
8,1.3081947090530395,0.52678,1.1443196024894715,0.5976
|
||||
9,1.2682366030120849,0.54628,1.0807610153198242,0.6125
|
||||
10,1.2392250786972046,0.55564,1.0270665140151978,0.6391
|
||||
11,1.2111315982818605,0.5651,1.017524287033081,0.6409
|
||||
12,1.1976571616363525,0.57106,0.9696040712356567,0.6622
|
||||
13,1.1811721408843994,0.57794,0.9810876142501831,0.6563
|
||||
14,1.16085556640625,0.58526,0.9509256943702697,0.6652
|
||||
15,1.152661986503601,0.58886,0.9255242247581482,0.6769
|
||||
16,1.1375279732131958,0.5958,0.9393579832077026,0.6742
|
||||
17,1.1235777826309203,0.59944,0.9380761775970459,0.6738
|
||||
18,1.125510050086975,0.6006,0.9149430327415466,0.6828
|
||||
19,1.107088499584198,0.60576,0.8883607380867005,0.6866
|
||||
20,1.0925483314132691,0.611,0.9142857453346253,0.6797
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.064392961997986,0.23846,1.7686689800262452,0.3761
|
||||
2,1.7451289974212647,0.3644,1.5356458026885986,0.4497
|
||||
3,1.588267235183716,0.42186,1.3689933862686157,0.5059
|
||||
4,1.519119600906372,0.44644,1.3433129314422607,0.5085
|
||||
5,1.4625473209762574,0.46948,1.277887591934204,0.5426
|
||||
6,1.4183696475982666,0.48196,1.2298621862411498,0.5589
|
||||
7,1.380669080886841,0.50072,1.257154960823059,0.5392
|
||||
8,1.344734386291504,0.51252,1.20913639087677,0.5659
|
||||
9,1.3033011206817626,0.53002,1.083999695968628,0.6162
|
||||
10,1.2634273228454589,0.5458,1.0644443402290344,0.6254
|
||||
11,1.2577572822189331,0.54942,1.030274422645569,0.6359
|
||||
12,1.2248346060943605,0.56484,1.0397949484825135,0.6295
|
||||
13,1.2027203491973877,0.57098,0.9911274933815002,0.6492
|
||||
14,1.187558348007202,0.57758,0.9643227214813233,0.6593
|
||||
15,1.1724095825195313,0.5828,0.9574517963409424,0.6634
|
||||
16,1.1600826649475098,0.5883,0.9531858426094055,0.6656
|
||||
17,1.148004737739563,0.59024,0.9419389196395874,0.6644
|
||||
18,1.135836477394104,0.59494,0.9213238637924195,0.6784
|
||||
19,1.117524945716858,0.60416,0.924640416431427,0.6748
|
||||
20,1.1149262449264525,0.60376,0.8988444912910462,0.6854
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0648545223999024,0.23818,1.765181273651123,0.3742
|
||||
2,1.7573804076766968,0.362,1.535401229095459,0.4471
|
||||
3,1.611904190979004,0.41362,1.4302934280395507,0.4844
|
||||
4,1.525922943458557,0.44314,1.3388441865921021,0.5087
|
||||
5,1.467007992515564,0.46652,1.2719555583953857,0.5469
|
||||
6,1.4222887012481689,0.48284,1.248225646018982,0.5555
|
||||
7,1.373562057991028,0.50482,1.169153240585327,0.5856
|
||||
8,1.3381992527008058,0.51634,1.1345109245300293,0.6014
|
||||
9,1.3071355611419677,0.53108,1.103178568649292,0.6051
|
||||
10,1.265291601448059,0.54506,1.071216807079315,0.6169
|
||||
11,1.244923010597229,0.55354,1.068904607105255,0.6212
|
||||
12,1.2326332139205933,0.55886,1.055773473072052,0.6307
|
||||
13,1.2040435157012939,0.57052,0.9872216351509094,0.6528
|
||||
14,1.1909678496170044,0.57444,0.9747416945457459,0.6634
|
||||
15,1.16843085231781,0.58078,0.9478087964057922,0.6734
|
||||
16,1.1639410251617432,0.58364,0.9701937286376953,0.6655
|
||||
17,1.14788178440094,0.59224,0.947301069355011,0.6735
|
||||
18,1.1326425197601317,0.5989,0.9041819067955017,0.6868
|
||||
19,1.1362577977752686,0.59788,0.9516706992149353,0.6716
|
||||
20,1.1200445556259155,0.60232,0.9014985984802246,0.6861
|
||||
|
@@ -0,0 +1,37 @@
|
||||
seed,optimizer,augmentation,test_acc,robustness
|
||||
42,sgd,none,0.7017,"{'0.1': 0.6874, '0.2': 0.6539, '0.3': 0.573}"
|
||||
42,sgd,standard,0.6983,"{'0.1': 0.685, '0.2': 0.6393, '0.3': 0.5324}"
|
||||
42,sgd,aggressive,0.6529,"{'0.1': 0.6441, '0.2': 0.5905, '0.3': 0.5073}"
|
||||
42,adam,none,0.5754,"{'0.1': 0.5688, '0.2': 0.5225, '0.3': 0.461}"
|
||||
42,adam,standard,0.5012,"{'0.1': 0.5008, '0.2': 0.4696, '0.3': 0.3933}"
|
||||
42,adam,aggressive,0.4534,"{'0.1': 0.443, '0.2': 0.4074, '0.3': 0.3669}"
|
||||
123,sgd,none,0.7018,"{'0.1': 0.6907, '0.2': 0.6513, '0.3': 0.5808}"
|
||||
123,sgd,standard,0.6987,"{'0.1': 0.688, '0.2': 0.6378, '0.3': 0.5488}"
|
||||
123,sgd,aggressive,0.6736,"{'0.1': 0.6591, '0.2': 0.6077, '0.3': 0.5265}"
|
||||
123,adam,none,0.5414,"{'0.1': 0.5207, '0.2': 0.4721, '0.3': 0.3951}"
|
||||
123,adam,standard,0.4439,"{'0.1': 0.4509, '0.2': 0.4256, '0.3': 0.3567}"
|
||||
123,adam,aggressive,0.4519,"{'0.1': 0.4502, '0.2': 0.4266, '0.3': 0.3584}"
|
||||
999,sgd,none,0.6778,"{'0.1': 0.669, '0.2': 0.6288, '0.3': 0.5506}"
|
||||
999,sgd,standard,0.6961,"{'0.1': 0.6862, '0.2': 0.6491, '0.3': 0.5655}"
|
||||
999,sgd,aggressive,0.6623,"{'0.1': 0.6504, '0.2': 0.5934, '0.3': 0.5163}"
|
||||
999,adam,none,0.5477,"{'0.1': 0.5394, '0.2': 0.4983, '0.3': 0.4188}"
|
||||
999,adam,standard,0.4508,"{'0.1': 0.4455, '0.2': 0.3987, '0.3': 0.308}"
|
||||
999,adam,aggressive,0.4209,"{'0.1': 0.4268, '0.2': 0.4071, '0.3': 0.3368}"
|
||||
42,sgd,none,0.7013,"{'0.1': 0.6895, '0.2': 0.6449, '0.3': 0.5712}"
|
||||
42,sgd,standard,0.6791,"{'0.1': 0.6702, '0.2': 0.6315, '0.3': 0.5601}"
|
||||
42,sgd,aggressive,0.6703,"{'0.1': 0.6505, '0.2': 0.587, '0.3': 0.5143}"
|
||||
42,adam,none,0.5658,"{'0.1': 0.5575, '0.2': 0.5204, '0.3': 0.4498}"
|
||||
42,adam,standard,0.4394,"{'0.1': 0.4385, '0.2': 0.4069, '0.3': 0.3476}"
|
||||
42,adam,aggressive,0.45,"{'0.1': 0.4467, '0.2': 0.4168, '0.3': 0.3557}"
|
||||
123,sgd,none,0.7002,"{'0.1': 0.6902, '0.2': 0.6511, '0.3': 0.5781}"
|
||||
123,sgd,standard,0.6951,"{'0.1': 0.6833, '0.2': 0.6248, '0.3': 0.5406}"
|
||||
123,sgd,aggressive,0.6766,"{'0.1': 0.6661, '0.2': 0.6188, '0.3': 0.5369}"
|
||||
123,adam,none,0.4857,"{'0.1': 0.4851, '0.2': 0.4572, '0.3': 0.4191}"
|
||||
123,adam,standard,0.4536,"{'0.1': 0.4517, '0.2': 0.4216, '0.3': 0.3551}"
|
||||
123,adam,aggressive,0.4542,"{'0.1': 0.4568, '0.2': 0.4363, '0.3': 0.3909}"
|
||||
999,sgd,none,0.6961,"{'0.1': 0.6845, '0.2': 0.6509, '0.3': 0.5702}"
|
||||
999,sgd,standard,0.6896,"{'0.1': 0.6757, '0.2': 0.6251, '0.3': 0.5515}"
|
||||
999,sgd,aggressive,0.663,"{'0.1': 0.6542, '0.2': 0.6143, '0.3': 0.5326}"
|
||||
999,adam,none,0.5189,"{'0.1': 0.5138, '0.2': 0.4787, '0.3': 0.4115}"
|
||||
999,adam,standard,0.4934,"{'0.1': 0.4822, '0.2': 0.4312, '0.3': 0.34}"
|
||||
999,adam,aggressive,0.4039,"{'0.1': 0.4053, '0.2': 0.3814, '0.3': 0.3183}"
|
||||
|
@@ -0,0 +1,398 @@
|
||||
[
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7017,
|
||||
"robustness": {
|
||||
"0.1": 0.6874,
|
||||
"0.2": 0.6539,
|
||||
"0.3": 0.573
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6983,
|
||||
"robustness": {
|
||||
"0.1": 0.685,
|
||||
"0.2": 0.6393,
|
||||
"0.3": 0.5324
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6529,
|
||||
"robustness": {
|
||||
"0.1": 0.6441,
|
||||
"0.2": 0.5905,
|
||||
"0.3": 0.5073
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5754,
|
||||
"robustness": {
|
||||
"0.1": 0.5688,
|
||||
"0.2": 0.5225,
|
||||
"0.3": 0.461
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.5012,
|
||||
"robustness": {
|
||||
"0.1": 0.5008,
|
||||
"0.2": 0.4696,
|
||||
"0.3": 0.3933
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4534,
|
||||
"robustness": {
|
||||
"0.1": 0.443,
|
||||
"0.2": 0.4074,
|
||||
"0.3": 0.3669
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7018,
|
||||
"robustness": {
|
||||
"0.1": 0.6907,
|
||||
"0.2": 0.6513,
|
||||
"0.3": 0.5808
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6987,
|
||||
"robustness": {
|
||||
"0.1": 0.688,
|
||||
"0.2": 0.6378,
|
||||
"0.3": 0.5488
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6736,
|
||||
"robustness": {
|
||||
"0.1": 0.6591,
|
||||
"0.2": 0.6077,
|
||||
"0.3": 0.5265
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5414,
|
||||
"robustness": {
|
||||
"0.1": 0.5207,
|
||||
"0.2": 0.4721,
|
||||
"0.3": 0.3951
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4439,
|
||||
"robustness": {
|
||||
"0.1": 0.4509,
|
||||
"0.2": 0.4256,
|
||||
"0.3": 0.3567
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4519,
|
||||
"robustness": {
|
||||
"0.1": 0.4502,
|
||||
"0.2": 0.4266,
|
||||
"0.3": 0.3584
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6778,
|
||||
"robustness": {
|
||||
"0.1": 0.669,
|
||||
"0.2": 0.6288,
|
||||
"0.3": 0.5506
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6961,
|
||||
"robustness": {
|
||||
"0.1": 0.6862,
|
||||
"0.2": 0.6491,
|
||||
"0.3": 0.5655
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6623,
|
||||
"robustness": {
|
||||
"0.1": 0.6504,
|
||||
"0.2": 0.5934,
|
||||
"0.3": 0.5163
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5477,
|
||||
"robustness": {
|
||||
"0.1": 0.5394,
|
||||
"0.2": 0.4983,
|
||||
"0.3": 0.4188
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4508,
|
||||
"robustness": {
|
||||
"0.1": 0.4455,
|
||||
"0.2": 0.3987,
|
||||
"0.3": 0.308
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4209,
|
||||
"robustness": {
|
||||
"0.1": 0.4268,
|
||||
"0.2": 0.4071,
|
||||
"0.3": 0.3368
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7013,
|
||||
"robustness": {
|
||||
"0.1": 0.6895,
|
||||
"0.2": 0.6449,
|
||||
"0.3": 0.5712
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6791,
|
||||
"robustness": {
|
||||
"0.1": 0.6702,
|
||||
"0.2": 0.6315,
|
||||
"0.3": 0.5601
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6703,
|
||||
"robustness": {
|
||||
"0.1": 0.6505,
|
||||
"0.2": 0.587,
|
||||
"0.3": 0.5143
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5658,
|
||||
"robustness": {
|
||||
"0.1": 0.5575,
|
||||
"0.2": 0.5204,
|
||||
"0.3": 0.4498
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4394,
|
||||
"robustness": {
|
||||
"0.1": 0.4385,
|
||||
"0.2": 0.4069,
|
||||
"0.3": 0.3476
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.45,
|
||||
"robustness": {
|
||||
"0.1": 0.4467,
|
||||
"0.2": 0.4168,
|
||||
"0.3": 0.3557
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7002,
|
||||
"robustness": {
|
||||
"0.1": 0.6902,
|
||||
"0.2": 0.6511,
|
||||
"0.3": 0.5781
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6951,
|
||||
"robustness": {
|
||||
"0.1": 0.6833,
|
||||
"0.2": 0.6248,
|
||||
"0.3": 0.5406
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6766,
|
||||
"robustness": {
|
||||
"0.1": 0.6661,
|
||||
"0.2": 0.6188,
|
||||
"0.3": 0.5369
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.4857,
|
||||
"robustness": {
|
||||
"0.1": 0.4851,
|
||||
"0.2": 0.4572,
|
||||
"0.3": 0.4191
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4536,
|
||||
"robustness": {
|
||||
"0.1": 0.4517,
|
||||
"0.2": 0.4216,
|
||||
"0.3": 0.3551
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4542,
|
||||
"robustness": {
|
||||
"0.1": 0.4568,
|
||||
"0.2": 0.4363,
|
||||
"0.3": 0.3909
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6961,
|
||||
"robustness": {
|
||||
"0.1": 0.6845,
|
||||
"0.2": 0.6509,
|
||||
"0.3": 0.5702
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6896,
|
||||
"robustness": {
|
||||
"0.1": 0.6757,
|
||||
"0.2": 0.6251,
|
||||
"0.3": 0.5515
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.663,
|
||||
"robustness": {
|
||||
"0.1": 0.6542,
|
||||
"0.2": 0.6143,
|
||||
"0.3": 0.5326
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5189,
|
||||
"robustness": {
|
||||
"0.1": 0.5138,
|
||||
"0.2": 0.4787,
|
||||
"0.3": 0.4115
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4934,
|
||||
"robustness": {
|
||||
"0.1": 0.4822,
|
||||
"0.2": 0.4312,
|
||||
"0.3": 0.34
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4039,
|
||||
"robustness": {
|
||||
"0.1": 0.4053,
|
||||
"0.2": 0.3814,
|
||||
"0.3": 0.3183
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,7 @@
|
||||
from torchvision.datasets import CIFAR10, CIFAR100
import tensorflow_datasets as tfds

# One-shot dataset fetcher: downloads the CIFAR-10 and CIFAR-100 train
# splits into data/ via torchvision (download=True is a no-op if cached).
ds10 = CIFAR10(root='data/', train=True, download=True)
ds100 = CIFAR100(root='data/', train=True, download=True)

# The corrupted CIFAR-10 variant is only distributed through
# tensorflow_datasets, hence the second dependency.
ds_c10c = tfds.load('cifar10_corrupted')
|
||||
@@ -0,0 +1,284 @@
|
||||
import os
|
||||
import argparse
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import statsmodels.api as sm
|
||||
from statsmodels.formula.api import ols
|
||||
from argparse import Namespace
|
||||
|
||||
# simple cnn model definition
|
||||
# I looked a lot at https://github.com/giusarno/SimpleCNN/blob/master/examples/cifar10/themodel.py
|
||||
# before making this class, mostly because I was not aware of the `MaxPool2d` function
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
    """Small two-stage CNN for 32x32 RGB images (CIFAR-style).

    Two conv+ReLU+maxpool stages (3 -> 32 -> 64 channels, each halving the
    spatial size to 8x8), followed by a dropout-regularized two-layer
    classifier head producing `num_classes` logits.
    """

    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        dropout_rate = 0.2  # fixed regularization strength for the head

        # feature extractor: 32x32 -> 16x16 -> 8x8
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # classifier head over the flattened 64x8x8 feature map
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Extract features, then classify; returns raw logits."""
        return self.classifier(self.features(x))
|
||||
|
||||
|
||||
def get_data_loaders(batch_size, augmentation):
    """Build CIFAR-10 train/test DataLoaders for a named augmentation policy.

    augmentation is one of 'none' | 'standard' | 'aggressive'; anything else
    raises ValueError. Datasets are downloaded to ./data on first use.
    Returns (train_loader, test_loader).
    """
    # per-policy training-time transforms (ToTensor is appended to each)
    policy_transforms = {
        'none': [],
        'standard': [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
        ],
        'aggressive': [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.RandomCrop(32, padding=4),
            transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                   saturation=0.2, hue=0.1),
        ],
    }
    if augmentation not in policy_transforms:
        raise ValueError(f"unknown augmentation: {augmentation}")

    transform_train = transforms.Compose(
        policy_transforms[augmentation] + [transforms.ToTensor()])
    # test data is never augmented
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, test_loader
|
||||
|
||||
|
||||
# train for 1 epoch
|
||||
def train_one_epoch(model, optimizer, criterion, dataloader, device, aug=True):
    """Train `model` for one full pass over `dataloader`.

    When `aug` is True (the default), each batch is perturbed with additive
    Gaussian noise whose std is drawn uniformly from [0, 0.2) before the
    forward pass; inputs are intentionally not clamped back to [0, 1].
    Returns (mean_loss, accuracy) over all samples seen this epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()

        if aug:
            # fresh per-batch noise level
            std = np.random.uniform(0, 0.2)
            batch_x = batch_x + std * torch.randn_like(batch_x)

        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()

        # accumulate sample-weighted loss and top-1 hits
        loss_sum += loss.item() * batch_x.size(0)
        n_correct += logits.max(1)[1].eq(batch_y).sum().item()
        n_seen += batch_y.size(0)

    return loss_sum / n_seen, n_correct / n_seen
|
||||
|
||||
|
||||
# eval on clean data
|
||||
def evaluate(model, criterion, dataloader, device):
    """Evaluate `model` on clean (un-noised) data.

    Returns (mean_loss, accuracy) computed sample-weighted over the whole
    loader, with gradients disabled and the model in eval mode.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    count = 0
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            batch_loss = criterion(logits, y)

            loss_sum += batch_loss.item() * x.size(0)
            hits += logits.max(1)[1].eq(y).sum().item()
            count += y.size(0)
    return loss_sum / count, hits / count
|
||||
|
||||
|
||||
# eval robustness under gaussian noise
|
||||
def evaluate_robustness(model, dataloader, device, noise_std):
    """Accuracy of `model` under additive Gaussian input noise.

    Noise of std `noise_std` is added to each batch before it is moved to
    `device` (matching the training-time recipe; inputs are intentionally
    not clamped back to [0, 1]). Returns top-1 accuracy over the loader.
    """
    model.eval()
    hits = 0
    count = 0
    with torch.no_grad():
        for x, y in dataloader:
            perturbed = x + noise_std * torch.randn_like(x)
            perturbed = perturbed.to(device)
            y = y.to(device)

            logits = model(perturbed)
            hits += logits.max(1)[1].eq(y).sum().item()
            count += y.size(0)
    return hits / count
|
||||
|
||||
|
||||
def analyze_results(results_path='results.json'):
    """Load experiment results, run a two-way ANOVA, and save a bar chart.

    Reads the JSON list produced by run_experiments from `results_path`,
    writes a flat CSV to analysis_results.csv, prints an
    optimizer x augmentation ANOVA table on test accuracy, and saves a
    per-condition bar chart to test_acc_comparison.png.

    Fix: the original created three separate bar plots — the first (with
    figsize, viridis colors, and tick-label thinning) was discarded, and a
    third, freshly-created figure was the one actually saved, so none of
    those adjustments appeared in the output file. Everything now targets a
    single figure.
    """
    with open(results_path) as f:
        results = json.load(f)
    df = pd.DataFrame(results)
    df.to_csv('analysis_results.csv', index=False)

    # full two-way ANOVA with interaction term
    model = ols('test_acc ~ C(optimizer) * C(augmentation)', data=df).fit()
    anova_table = sm.stats.anova_lm(model, typ=2)
    print('anova on test accuracy:')
    print(anova_table)

    # composite label identifying each optimizer/augmentation cell
    df['condition'] = df['optimizer'] + '_' + df['augmentation']

    fig, ax = plt.subplots(figsize=(12, 8))
    # viridis colormap for the bars
    colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(df)))
    df.plot.bar(x='condition', y='test_acc', rot=45, color=colors, ax=ax)
    ax.set_ylabel('test accuracy')

    # only show every other tick label to avoid overcrowding
    tick_labels = ax.get_xticklabels()
    new_labels = [label.get_text() if i % 2 == 0 else ""
                  for i, label in enumerate(tick_labels)]
    ax.set_xticklabels(new_labels)

    fig.tight_layout()
    fig.savefig('test_acc_comparison.png')
    print('saved plot to test_acc_comparison.png')
|
||||
|
||||
|
||||
def run_experiments(args):
    """Run the full seed x optimizer x augmentation grid on CIFAR-10.

    For each (seed, optimizer, augmentation) cell: trains SimpleCNN for
    args.epochs epochs, evaluates clean test accuracy after each epoch,
    measures robustness of the final model to Gaussian input noise at
    std 0.1/0.2/0.3, writes the per-epoch history to
    analysis/history_<opt>_<aug>_<seed>.csv, and finally dumps the summary
    list to results.json.

    Fix: the history CSVs are written under analysis/, which was never
    created — on a fresh checkout the first to_csv raised
    FileNotFoundError. The directory is now created up front.

    NOTE(review): train_one_epoch is called with its default aug=True, so
    Gaussian input noise is injected during training in *every* condition,
    including augmentation='none' — confirm this is intended.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    results = []

    # ensure the output directory for per-run history CSVs exists
    os.makedirs('analysis', exist_ok=True)

    optimizers = {
        'sgd': lambda params: optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=1e-3),
        'adam': lambda params: optim.Adam(params, lr=args.lr, weight_decay=1e-3)
    }
    augmentations = ['none', 'standard', 'aggressive']

    seeds = [42, 123, 999]
    for seed in seeds:
        print("SEED", seed)
        # seed both torch (weights, shuffling) and numpy (noise levels)
        torch.manual_seed(seed)
        np.random.seed(seed)

        for opt_name in optimizers:
            for aug in augmentations:
                train_loader, test_loader = get_data_loaders(args.batch_size, aug)
                noise_levels = [0.1, 0.2, 0.3]
                model = SimpleCNN(num_classes=10).to(device)
                optimizer = optimizers[opt_name](model.parameters())
                criterion = nn.CrossEntropyLoss()
                history = {
                    'epoch': [], 'train_loss': [], 'train_acc': [],
                    'test_loss': [], 'test_acc': []
                }

                for epoch in range(args.epochs):
                    train_loss, train_acc = train_one_epoch(
                        model, optimizer, criterion, train_loader, device)

                    test_loss, test_acc = evaluate(
                        model, criterion, test_loader, device)

                    history['epoch'].append(epoch + 1)
                    history['train_loss'].append(train_loss)
                    history['train_acc'].append(train_acc)
                    history['test_loss'].append(test_loss)
                    history['test_acc'].append(test_acc)

                    print(f"[{opt_name}][{aug}][epoch {epoch + 1}] "
                          f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
                          f"test_acc={test_acc:.4f}")

                # robustness of the final trained model at each noise level
                robustness = {noise: evaluate_robustness(
                    model, test_loader, device, noise) for noise in noise_levels}

                pd.DataFrame(history).to_csv(
                    f"analysis/history_{opt_name}_{aug}_{seed}.csv", index=False)

                results.append({
                    'seed': seed,
                    'optimizer': opt_name,
                    'augmentation': aug,
                    'test_acc': test_acc,
                    'robustness': robustness
                })

    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print('saved results to results.json')
|
||||
|
||||
|
||||
# credit: I gave chatgpt a list of args and it made the arg parser for me
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--analyze', action='store_true',
                        help='run analysis on results')
    args = parser.parse_args()

    # echo the configuration as JSON and as the Namespace repr
    print(json.dumps(vars(args)))
    print(args)

    if args.analyze:
        analyze_results("combined_results.json")
    else:
        run_experiments(args)
|
||||
@@ -0,0 +1,19 @@
|
||||
seed,optimizer,augmentation,test_acc,robustness
|
||||
42,sgd,none,0.7088,"{'0.1': 0.6319, '0.2': 0.4336, '0.3': 0.2913}"
|
||||
42,sgd,standard,0.6859,"{'0.1': 0.5952, '0.2': 0.4019, '0.3': 0.2757}"
|
||||
42,sgd,aggressive,0.6536,"{'0.1': 0.5778, '0.2': 0.43, '0.3': 0.2943}"
|
||||
42,adam,none,0.5451,"{'0.1': 0.4221, '0.2': 0.2298, '0.3': 0.1545}"
|
||||
42,adam,standard,0.5101,"{'0.1': 0.454, '0.2': 0.2098, '0.3': 0.1324}"
|
||||
42,adam,aggressive,0.4427,"{'0.1': 0.4048, '0.2': 0.2461, '0.3': 0.1547}"
|
||||
123,sgd,none,0.6974,"{'0.1': 0.63, '0.2': 0.4452, '0.3': 0.312}"
|
||||
123,sgd,standard,0.6674,"{'0.1': 0.6252, '0.2': 0.4146, '0.3': 0.2764}"
|
||||
123,sgd,aggressive,0.6691,"{'0.1': 0.6179, '0.2': 0.4691, '0.3': 0.3423}"
|
||||
123,adam,none,0.6049,"{'0.1': 0.4685, '0.2': 0.3387, '0.3': 0.2378}"
|
||||
123,adam,standard,0.4654,"{'0.1': 0.4071, '0.2': 0.3073, '0.3': 0.2341}"
|
||||
123,adam,aggressive,0.5096,"{'0.1': 0.4624, '0.2': 0.3219, '0.3': 0.2159}"
|
||||
999,sgd,none,0.7058,"{'0.1': 0.6252, '0.2': 0.3848, '0.3': 0.2276}"
|
||||
999,sgd,standard,0.6861,"{'0.1': 0.6002, '0.2': 0.4184, '0.3': 0.2986}"
|
||||
999,sgd,aggressive,0.6595,"{'0.1': 0.5775, '0.2': 0.4165, '0.3': 0.2899}"
|
||||
999,adam,none,0.5573,"{'0.1': 0.4562, '0.2': 0.293, '0.3': 0.2167}"
|
||||
999,adam,standard,0.4835,"{'0.1': 0.4136, '0.2': 0.2221, '0.3': 0.1548}"
|
||||
999,adam,aggressive,0.5123,"{'0.1': 0.449, '0.2': 0.2571, '0.3': 0.1658}"
|
||||
|
@@ -0,0 +1,19 @@
|
||||
seed,optimizer,augmentation,test_acc,robustness
|
||||
42,sgd,none,0.7088,"{'0.1': 0.6319, '0.2': 0.4336, '0.3': 0.2913}"
|
||||
42,sgd,standard,0.6859,"{'0.1': 0.5952, '0.2': 0.4019, '0.3': 0.2757}"
|
||||
42,sgd,aggressive,0.6536,"{'0.1': 0.5778, '0.2': 0.43, '0.3': 0.2943}"
|
||||
42,adam,none,0.5451,"{'0.1': 0.4221, '0.2': 0.2298, '0.3': 0.1545}"
|
||||
42,adam,standard,0.5101,"{'0.1': 0.454, '0.2': 0.2098, '0.3': 0.1324}"
|
||||
42,adam,aggressive,0.4427,"{'0.1': 0.4048, '0.2': 0.2461, '0.3': 0.1547}"
|
||||
123,sgd,none,0.6974,"{'0.1': 0.63, '0.2': 0.4452, '0.3': 0.312}"
|
||||
123,sgd,standard,0.6674,"{'0.1': 0.6252, '0.2': 0.4146, '0.3': 0.2764}"
|
||||
123,sgd,aggressive,0.6691,"{'0.1': 0.6179, '0.2': 0.4691, '0.3': 0.3423}"
|
||||
123,adam,none,0.6049,"{'0.1': 0.4685, '0.2': 0.3387, '0.3': 0.2378}"
|
||||
123,adam,standard,0.4654,"{'0.1': 0.4071, '0.2': 0.3073, '0.3': 0.2341}"
|
||||
123,adam,aggressive,0.5096,"{'0.1': 0.4624, '0.2': 0.3219, '0.3': 0.2159}"
|
||||
999,sgd,none,0.7058,"{'0.1': 0.6252, '0.2': 0.3848, '0.3': 0.2276}"
|
||||
999,sgd,standard,0.6861,"{'0.1': 0.6002, '0.2': 0.4184, '0.3': 0.2986}"
|
||||
999,sgd,aggressive,0.6595,"{'0.1': 0.5775, '0.2': 0.4165, '0.3': 0.2899}"
|
||||
999,adam,none,0.5573,"{'0.1': 0.4562, '0.2': 0.293, '0.3': 0.2167}"
|
||||
999,adam,standard,0.4835,"{'0.1': 0.4136, '0.2': 0.2221, '0.3': 0.1548}"
|
||||
999,adam,aggressive,0.5123,"{'0.1': 0.449, '0.2': 0.2571, '0.3': 0.1658}"
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8846732257843017,0.3103,1.5956477916717529,0.4208
|
||||
2,1.6917855532073975,0.38452,1.5259598985671996,0.439
|
||||
3,1.6328186359024048,0.40266,1.5264780994415283,0.4508
|
||||
4,1.5851124533462524,0.42526,1.424309602355957,0.4878
|
||||
5,1.5511348150634765,0.43632,1.4553828117370606,0.4854
|
||||
6,1.5411012967681885,0.44224,1.5167289329528808,0.4561
|
||||
7,1.5281506721115112,0.4473,1.3984178335189819,0.5015
|
||||
8,1.5113245751190185,0.4561,1.3988324977874755,0.5007
|
||||
9,1.5111655992889403,0.45554,1.4732477207183838,0.4733
|
||||
10,1.498042034225464,0.46026,1.3725576107025146,0.5096
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0559358141326904,0.24472,1.828050934791565,0.3374
|
||||
2,1.857869116783142,0.3182,1.6779026025772095,0.3909
|
||||
3,1.7681750354385375,0.35418,1.6347509387969972,0.4094
|
||||
4,1.7275719815826416,0.37254,1.6000082025527953,0.4225
|
||||
5,1.699682864151001,0.37816,1.571951426887512,0.4362
|
||||
6,1.6953852003860475,0.38134,1.582085866355896,0.4284
|
||||
7,1.6721481609344482,0.392,1.584311390686035,0.4267
|
||||
8,1.6576726916885376,0.39738,1.5268918621063232,0.443
|
||||
9,1.6475247369384765,0.402,1.5393544744491576,0.4441
|
||||
10,1.6505113174819945,0.39736,1.5213500274658203,0.4427
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9911173669052125,0.27896,1.6828551885604859,0.3994
|
||||
2,1.6908308898162843,0.38222,1.6040934282302857,0.4364
|
||||
3,1.6209180907440186,0.41298,1.4993294555664063,0.459
|
||||
4,1.591390087928772,0.42586,1.5083149290084839,0.4616
|
||||
5,1.5589105658340454,0.43662,1.428050018119812,0.4809
|
||||
6,1.540269641456604,0.44422,1.4299810447692871,0.4858
|
||||
7,1.5332854122543336,0.44906,1.3867164831161498,0.5119
|
||||
8,1.5187089794921875,0.4558,1.3682811828613282,0.5136
|
||||
9,1.5207957489013673,0.45552,1.4490770374298096,0.4736
|
||||
10,1.504308384437561,0.4596,1.3697486953735352,0.5123
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.6660349139404298,0.39084,1.3631993772506714,0.5092
|
||||
2,1.306575980796814,0.53194,1.2523228994369506,0.5514
|
||||
3,1.189557453918457,0.57828,1.2570601411819458,0.5594
|
||||
4,1.122563935699463,0.60502,1.187407196044922,0.5902
|
||||
5,1.0689283304977417,0.6233,1.2376863445281983,0.5659
|
||||
6,1.0334041289138793,0.6345,1.170531530189514,0.592
|
||||
7,0.9959703926849365,0.64648,1.1732149269104004,0.5901
|
||||
8,0.9762340403366089,0.65428,1.185727474975586,0.5919
|
||||
9,0.962807445487976,0.65934,1.1849331123352052,0.5937
|
||||
10,0.9365197282409667,0.66622,1.1537712555885316,0.6049
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.7286399437713622,0.37382,1.4246170207977296,0.4816
|
||||
2,1.3806899689483643,0.50334,1.3549695669174195,0.5164
|
||||
3,1.3021614087677003,0.53534,1.332729033279419,0.5202
|
||||
4,1.2415979095458984,0.55796,1.2883787483215332,0.542
|
||||
5,1.1996409232330323,0.5716,1.2885587697982788,0.5404
|
||||
6,1.1598682831573486,0.58466,1.2654446369171142,0.5521
|
||||
7,1.1435769744873048,0.59348,1.2620088846206665,0.5593
|
||||
8,1.122702364425659,0.60254,1.2513995946884156,0.5613
|
||||
9,1.0767982495117188,0.6175,1.2432810546875,0.5644
|
||||
10,1.0651884030914307,0.6178,1.321402802658081,0.5451
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.7729383182525635,0.35124,1.5228908493041993,0.4427
|
||||
2,1.430595055770874,0.482,1.38000087184906,0.5018
|
||||
3,1.32311696308136,0.52478,1.313341820716858,0.5259
|
||||
4,1.2678979627990723,0.54632,1.2857267393112182,0.5393
|
||||
5,1.2276857889556885,0.56124,1.299703761291504,0.5393
|
||||
6,1.1998348657989502,0.57262,1.2608893209457397,0.5565
|
||||
7,1.1683492805480957,0.58382,1.27673601436615,0.5546
|
||||
8,1.157327294807434,0.5878,1.2376579500198364,0.5654
|
||||
9,1.131808783454895,0.59764,1.2557779052734375,0.5572
|
||||
10,1.123384162902832,0.60062,1.2491890354156494,0.5573
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8813365224456786,0.30172,1.6416668697357177,0.3985
|
||||
2,1.6806913624954223,0.37672,1.5487596210479737,0.4274
|
||||
3,1.6351341115570068,0.40124,1.5625509996414184,0.4249
|
||||
4,1.6064803477859497,0.41056,1.511009596824646,0.4455
|
||||
5,1.5825084410476684,0.42064,1.5225824378967285,0.4431
|
||||
6,1.5676781223678589,0.42982,1.4784889535903931,0.4585
|
||||
7,1.5567302836608887,0.43112,1.4934770944595337,0.4557
|
||||
8,1.5449237117004395,0.43588,1.463665308380127,0.4715
|
||||
9,1.5365594310760498,0.44164,1.4671966415405273,0.4741
|
||||
10,1.5271776266098023,0.44406,1.4860247068405152,0.4654
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9121077362060548,0.28534,1.6980935123443603,0.3717
|
||||
2,1.634671248512268,0.3994,1.5801765157699585,0.4126
|
||||
3,1.5572229358291625,0.43102,1.4792659259796141,0.4583
|
||||
4,1.5200784818267823,0.44598,1.4566459318161011,0.4732
|
||||
5,1.4923802797317505,0.4583,1.3978405960083007,0.4928
|
||||
6,1.4686766564178466,0.46806,1.3776366807937621,0.503
|
||||
7,1.4700688259124757,0.46884,1.3533947219848632,0.5163
|
||||
8,1.4470393926620484,0.4773,1.345004845237732,0.5178
|
||||
9,1.4390167654800414,0.47734,1.3564238325119018,0.5147
|
||||
10,1.4347680015182496,0.48004,1.3694934066772462,0.5101
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8785975454711914,0.30716,1.543831322669983,0.4375
|
||||
2,1.6268331103515625,0.40026,1.5018853443145752,0.4532
|
||||
3,1.5749516596221924,0.42014,1.5034992160797118,0.452
|
||||
4,1.5602307732009888,0.42962,1.4366535945892334,0.477
|
||||
5,1.5448442478561402,0.43214,1.439453507232666,0.4691
|
||||
6,1.5286526789474488,0.43836,1.4527703874588012,0.4671
|
||||
7,1.5160772939300537,0.4433,1.4506785593032836,0.4667
|
||||
8,1.5187876448822022,0.44144,1.4333687414169312,0.4782
|
||||
9,1.5096457154083252,0.44728,1.4133702894210816,0.4894
|
||||
10,1.495245510482788,0.45166,1.411236897277832,0.4835
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0651948305892947,0.24042,1.729777430343628,0.3836
|
||||
2,1.7303296995544433,0.37548,1.49167406539917,0.4667
|
||||
3,1.5553171773147583,0.4367,1.3456432037353516,0.5205
|
||||
4,1.4635125664901734,0.47256,1.2475262697219849,0.5611
|
||||
5,1.3722680527114868,0.50766,1.1914715614318847,0.5746
|
||||
6,1.3082861290740966,0.53284,1.121172977924347,0.5981
|
||||
7,1.2505797283554076,0.5567,1.1289072477340698,0.5922
|
||||
8,1.2003705361175536,0.57516,1.0124303802490235,0.6468
|
||||
9,1.1683612518310547,0.58702,0.9980959354400635,0.65
|
||||
10,1.1325882072067261,0.6023,0.9680170437812805,0.6691
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0705010092163088,0.22946,1.7674200216293334,0.3802
|
||||
2,1.7432165607452392,0.36844,1.533634359550476,0.4416
|
||||
3,1.5778629602050782,0.42886,1.3606201539993286,0.5097
|
||||
4,1.4700220097351073,0.46802,1.2655291788101197,0.5451
|
||||
5,1.3893862116241456,0.50128,1.204566476535797,0.5692
|
||||
6,1.3244363966751098,0.5255,1.2020204118728637,0.5744
|
||||
7,1.2688888247299195,0.54674,1.0509296971321105,0.6334
|
||||
8,1.2191224797821045,0.56592,1.0520767385482788,0.6286
|
||||
9,1.1705705532073976,0.58452,1.0169856172561647,0.6394
|
||||
10,1.1425927544021606,0.59578,0.9944705787658692,0.6536
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.028081403198242,0.25312,1.7972938034057617,0.3576
|
||||
2,1.6919663692855835,0.38654,1.4519519706726074,0.4827
|
||||
3,1.5530475779342652,0.43992,1.3310353057861328,0.5251
|
||||
4,1.4540857498931885,0.47812,1.292456304550171,0.5413
|
||||
5,1.3654178101348877,0.51218,1.1448781085014343,0.6
|
||||
6,1.3064117618942261,0.53414,1.1193419974327088,0.6077
|
||||
7,1.2500248126220703,0.5551,1.0644649827957153,0.6217
|
||||
8,1.2079097275543214,0.5713,1.0305480089187622,0.6358
|
||||
9,1.1659785708236694,0.58452,0.9741595977783203,0.6535
|
||||
10,1.1451263586425782,0.59474,0.9785669965744018,0.6595
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9007139762496947,0.30904,1.5374012899398803,0.4448
|
||||
2,1.4438594173431396,0.48224,1.360331530570984,0.5054
|
||||
3,1.2486270139312745,0.5548,1.1819912509918213,0.5701
|
||||
4,1.1134616537475586,0.60458,1.0598442848205567,0.6256
|
||||
5,0.9943869771957398,0.6509,0.9994649070739746,0.6511
|
||||
6,0.9061599386596679,0.68082,0.9807863761901855,0.6526
|
||||
7,0.832235673122406,0.70842,0.9603629438400269,0.6641
|
||||
8,0.7533782648849487,0.7388,0.9113856033325195,0.6857
|
||||
9,0.6814181573486328,0.76166,0.9031674418449401,0.6972
|
||||
10,0.6241655448532104,0.78184,0.8935547742843628,0.6974
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9639052035522462,0.28114,1.6098346630096436,0.4097
|
||||
2,1.4680433963775634,0.46986,1.302504965209961,0.5284
|
||||
3,1.2549403902435303,0.5519,1.225121039199829,0.5624
|
||||
4,1.1094685572624206,0.60738,1.0919673002243042,0.612
|
||||
5,0.9939155144119263,0.6497,0.9837436180114746,0.659
|
||||
6,0.8942083934020996,0.6852,0.957837422657013,0.6612
|
||||
7,0.8110888798904419,0.71492,0.9374553295135498,0.6761
|
||||
8,0.744535570487976,0.7387,0.9072779047966003,0.6925
|
||||
9,0.6828387829208374,0.76086,0.8724043285369874,0.7033
|
||||
10,0.6119843885612488,0.78486,0.8666944108009338,0.7088
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9422726140975952,0.28742,1.637837126159668,0.4084
|
||||
2,1.4907189490509034,0.46088,1.3707801050186157,0.5124
|
||||
3,1.2916603052902222,0.5373,1.1907301538467407,0.5763
|
||||
4,1.1448502359771728,0.59362,1.0813807232856751,0.6132
|
||||
5,1.0147696270942688,0.64488,0.9749312539100647,0.658
|
||||
6,0.9084624612808228,0.68092,0.971125221824646,0.6546
|
||||
7,0.8264397340011597,0.70914,0.9333860067367554,0.6724
|
||||
8,0.7552365953445435,0.73546,0.9200770247459411,0.6771
|
||||
9,0.6857615605545044,0.76032,0.8947711204528809,0.6944
|
||||
10,0.6202431632995605,0.78266,0.8700129133224487,0.7058
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.029623507652283,0.2529,1.688811130142212,0.3927
|
||||
2,1.6488482897186278,0.40018,1.4374095891952514,0.4752
|
||||
3,1.4772799210739136,0.46326,1.3251022443771363,0.5178
|
||||
4,1.3625635675048828,0.50938,1.2472985363006592,0.5517
|
||||
5,1.2698390966033934,0.54114,1.1421482133865357,0.5942
|
||||
6,1.1823028423309325,0.57602,1.0808060108184814,0.6158
|
||||
7,1.118297325630188,0.60142,0.9704109483718872,0.6591
|
||||
8,1.0673625009155274,0.62172,0.9435201133728027,0.6721
|
||||
9,1.0345378282928466,0.6332,0.9358344539642334,0.6687
|
||||
10,0.9860974870300293,0.65256,0.950194506072998,0.6674
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.009225925979614,0.26266,1.6814233026504517,0.3895
|
||||
2,1.6603656655120849,0.39688,1.4227338985443114,0.4802
|
||||
3,1.4855073391342164,0.45954,1.3364979927062988,0.5176
|
||||
4,1.3897099987411499,0.49994,1.2245563619613646,0.562
|
||||
5,1.2903345095443726,0.5355,1.2043679784774781,0.5667
|
||||
6,1.2084671304321288,0.57048,1.0968135701179504,0.6135
|
||||
7,1.1449422802734375,0.5905,1.0201958515167235,0.6367
|
||||
8,1.0922849000167847,0.61064,0.9836806530952453,0.6468
|
||||
9,1.0471580500793456,0.62822,0.9301959774017334,0.6716
|
||||
10,1.0136299449157715,0.64068,0.9022101957321167,0.6859
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.959310048789978,0.28234,1.5769516773223877,0.4368
|
||||
2,1.5942330165863037,0.42066,1.4556853477478027,0.4914
|
||||
3,1.443750505142212,0.47874,1.316703634262085,0.5271
|
||||
4,1.3386957333755494,0.51598,1.1774909160614013,0.5831
|
||||
5,1.2375807570266724,0.55794,1.0962469652175904,0.6125
|
||||
6,1.164957059440613,0.58492,1.0318147755622864,0.6338
|
||||
7,1.102748593158722,0.60784,0.9843499254226684,0.6612
|
||||
8,1.0541843451309205,0.62636,0.9331276074409485,0.6778
|
||||
9,1.0117439514541626,0.64112,0.8982747142791748,0.6885
|
||||
10,0.982185898399353,0.65222,0.9087156764984131,0.6861
|
||||
|
@@ -0,0 +1,60 @@
|
||||
[sgd][none][epoch 0] train_loss=1.9493, train_acc=0.2892, test_acc=0.4237
|
||||
[sgd][none][epoch 1] train_loss=1.4806, train_acc=0.4651, test_acc=0.5235
|
||||
[sgd][none][epoch 2] train_loss=1.2920, train_acc=0.5367, test_acc=0.5619
|
||||
[sgd][none][epoch 3] train_loss=1.1722, train_acc=0.5809, test_acc=0.6119
|
||||
[sgd][none][epoch 4] train_loss=1.0507, train_acc=0.6286, test_acc=0.6334
|
||||
[sgd][none][epoch 5] train_loss=0.9572, train_acc=0.6634, test_acc=0.6452
|
||||
[sgd][none][epoch 6] train_loss=0.8812, train_acc=0.6916, test_acc=0.6703
|
||||
[sgd][none][epoch 7] train_loss=0.7986, train_acc=0.7200, test_acc=0.6722
|
||||
[sgd][none][epoch 8] train_loss=0.7448, train_acc=0.7412, test_acc=0.6824
|
||||
[sgd][none][epoch 9] train_loss=0.6798, train_acc=0.7641, test_acc=0.6843
|
||||
[sgd][standard][epoch 0] train_loss=2.0006, train_acc=0.2638, test_acc=0.4103
|
||||
[sgd][standard][epoch 1] train_loss=1.6251, train_acc=0.4074, test_acc=0.4890
|
||||
[sgd][standard][epoch 2] train_loss=1.4750, train_acc=0.4641, test_acc=0.5426
|
||||
[sgd][standard][epoch 3] train_loss=1.3654, train_acc=0.5054, test_acc=0.5678
|
||||
[sgd][standard][epoch 4] train_loss=1.2646, train_acc=0.5472, test_acc=0.6111
|
||||
[sgd][standard][epoch 5] train_loss=1.1843, train_acc=0.5760, test_acc=0.6166
|
||||
[sgd][standard][epoch 6] train_loss=1.1222, train_acc=0.5997, test_acc=0.6571
|
||||
[sgd][standard][epoch 7] train_loss=1.0737, train_acc=0.6188, test_acc=0.6665
|
||||
[sgd][standard][epoch 8] train_loss=1.0308, train_acc=0.6354, test_acc=0.6872
|
||||
[sgd][standard][epoch 9] train_loss=0.9978, train_acc=0.6465, test_acc=0.6807
|
||||
[sgd][aggressive][epoch 0] train_loss=2.0480, train_acc=0.2484, test_acc=0.3840
|
||||
[sgd][aggressive][epoch 1] train_loss=1.7163, train_acc=0.3802, test_acc=0.4501
|
||||
[sgd][aggressive][epoch 2] train_loss=1.5662, train_acc=0.4333, test_acc=0.5033
|
||||
[sgd][aggressive][epoch 3] train_loss=1.4807, train_acc=0.4668, test_acc=0.5330
|
||||
[sgd][aggressive][epoch 4] train_loss=1.4095, train_acc=0.4943, test_acc=0.5762
|
||||
[sgd][aggressive][epoch 5] train_loss=1.3395, train_acc=0.5195, test_acc=0.5879
|
||||
[sgd][aggressive][epoch 6] train_loss=1.2735, train_acc=0.5444, test_acc=0.6154
|
||||
[sgd][aggressive][epoch 7] train_loss=1.2203, train_acc=0.5677, test_acc=0.6368
|
||||
[sgd][aggressive][epoch 8] train_loss=1.1891, train_acc=0.5792, test_acc=0.6300
|
||||
[sgd][aggressive][epoch 9] train_loss=1.1479, train_acc=0.5907, test_acc=0.6630
|
||||
[adam][none][epoch 0] train_loss=1.7509, train_acc=0.3614, test_acc=0.4276
|
||||
[adam][none][epoch 1] train_loss=1.4346, train_acc=0.4818, test_acc=0.4860
|
||||
[adam][none][epoch 2] train_loss=1.3425, train_acc=0.5193, test_acc=0.5122
|
||||
[adam][none][epoch 3] train_loss=1.2968, train_acc=0.5353, test_acc=0.5197
|
||||
[adam][none][epoch 4] train_loss=1.2610, train_acc=0.5499, test_acc=0.5428
|
||||
[adam][none][epoch 5] train_loss=1.2298, train_acc=0.5618, test_acc=0.5206
|
||||
[adam][none][epoch 6] train_loss=1.2102, train_acc=0.5682, test_acc=0.5455
|
||||
[adam][none][epoch 7] train_loss=1.1824, train_acc=0.5800, test_acc=0.5495
|
||||
[adam][none][epoch 8] train_loss=1.1591, train_acc=0.5886, test_acc=0.5656
|
||||
[adam][none][epoch 9] train_loss=1.1332, train_acc=0.5972, test_acc=0.5696
|
||||
[adam][standard][epoch 0] train_loss=1.9005, train_acc=0.3018, test_acc=0.4193
|
||||
[adam][standard][epoch 1] train_loss=1.6180, train_acc=0.4022, test_acc=0.4547
|
||||
[adam][standard][epoch 2] train_loss=1.5576, train_acc=0.4308, test_acc=0.4751
|
||||
[adam][standard][epoch 3] train_loss=1.5089, train_acc=0.4519, test_acc=0.4908
|
||||
[adam][standard][epoch 4] train_loss=1.4817, train_acc=0.4578, test_acc=0.4807
|
||||
[adam][standard][epoch 5] train_loss=1.4661, train_acc=0.4690, test_acc=0.4925
|
||||
[adam][standard][epoch 6] train_loss=1.4498, train_acc=0.4750, test_acc=0.5123
|
||||
[adam][standard][epoch 7] train_loss=1.4318, train_acc=0.4831, test_acc=0.4820
|
||||
[adam][standard][epoch 8] train_loss=1.4296, train_acc=0.4812, test_acc=0.5210
|
||||
[adam][standard][epoch 9] train_loss=1.4231, train_acc=0.4860, test_acc=0.5161
|
||||
[adam][aggressive][epoch 0] train_loss=1.9556, train_acc=0.2839, test_acc=0.3976
|
||||
[adam][aggressive][epoch 1] train_loss=1.7166, train_acc=0.3748, test_acc=0.4414
|
||||
[adam][aggressive][epoch 2] train_loss=1.6507, train_acc=0.4009, test_acc=0.4486
|
||||
[adam][aggressive][epoch 3] train_loss=1.6179, train_acc=0.4119, test_acc=0.4693
|
||||
[adam][aggressive][epoch 4] train_loss=1.5985, train_acc=0.4178, test_acc=0.4676
|
||||
[adam][aggressive][epoch 5] train_loss=1.5799, train_acc=0.4264, test_acc=0.4788
|
||||
[adam][aggressive][epoch 6] train_loss=1.5763, train_acc=0.4274, test_acc=0.4759
|
||||
[adam][aggressive][epoch 7] train_loss=1.5635, train_acc=0.4340, test_acc=0.4687
|
||||
[adam][aggressive][epoch 8] train_loss=1.5546, train_acc=0.4359, test_acc=0.4992
|
||||
[adam][aggressive][epoch 9] train_loss=1.5463, train_acc=0.4410, test_acc=0.4831
|
||||
@@ -0,0 +1,200 @@
|
||||
[
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7088,
|
||||
"robustness": {
|
||||
"0.1": 0.6319,
|
||||
"0.2": 0.4336,
|
||||
"0.3": 0.2913
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6859,
|
||||
"robustness": {
|
||||
"0.1": 0.5952,
|
||||
"0.2": 0.4019,
|
||||
"0.3": 0.2757
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6536,
|
||||
"robustness": {
|
||||
"0.1": 0.5778,
|
||||
"0.2": 0.43,
|
||||
"0.3": 0.2943
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5451,
|
||||
"robustness": {
|
||||
"0.1": 0.4221,
|
||||
"0.2": 0.2298,
|
||||
"0.3": 0.1545
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.5101,
|
||||
"robustness": {
|
||||
"0.1": 0.454,
|
||||
"0.2": 0.2098,
|
||||
"0.3": 0.1324
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4427,
|
||||
"robustness": {
|
||||
"0.1": 0.4048,
|
||||
"0.2": 0.2461,
|
||||
"0.3": 0.1547
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6974,
|
||||
"robustness": {
|
||||
"0.1": 0.63,
|
||||
"0.2": 0.4452,
|
||||
"0.3": 0.312
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6674,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.4146,
|
||||
"0.3": 0.2764
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6691,
|
||||
"robustness": {
|
||||
"0.1": 0.6179,
|
||||
"0.2": 0.4691,
|
||||
"0.3": 0.3423
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6049,
|
||||
"robustness": {
|
||||
"0.1": 0.4685,
|
||||
"0.2": 0.3387,
|
||||
"0.3": 0.2378
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4654,
|
||||
"robustness": {
|
||||
"0.1": 0.4071,
|
||||
"0.2": 0.3073,
|
||||
"0.3": 0.2341
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5096,
|
||||
"robustness": {
|
||||
"0.1": 0.4624,
|
||||
"0.2": 0.3219,
|
||||
"0.3": 0.2159
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7058,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.3848,
|
||||
"0.3": 0.2276
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6861,
|
||||
"robustness": {
|
||||
"0.1": 0.6002,
|
||||
"0.2": 0.4184,
|
||||
"0.3": 0.2986
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6595,
|
||||
"robustness": {
|
||||
"0.1": 0.5775,
|
||||
"0.2": 0.4165,
|
||||
"0.3": 0.2899
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5573,
|
||||
"robustness": {
|
||||
"0.1": 0.4562,
|
||||
"0.2": 0.293,
|
||||
"0.3": 0.2167
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4835,
|
||||
"robustness": {
|
||||
"0.1": 0.4136,
|
||||
"0.2": 0.2221,
|
||||
"0.3": 0.1548
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5123,
|
||||
"robustness": {
|
||||
"0.1": 0.449,
|
||||
"0.2": 0.2571,
|
||||
"0.3": 0.1658
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,200 @@
|
||||
[
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7088,
|
||||
"robustness": {
|
||||
"0.1": 0.6319,
|
||||
"0.2": 0.4336,
|
||||
"0.3": 0.2913
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6859,
|
||||
"robustness": {
|
||||
"0.1": 0.5952,
|
||||
"0.2": 0.4019,
|
||||
"0.3": 0.2757
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6536,
|
||||
"robustness": {
|
||||
"0.1": 0.5778,
|
||||
"0.2": 0.43,
|
||||
"0.3": 0.2943
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5451,
|
||||
"robustness": {
|
||||
"0.1": 0.4221,
|
||||
"0.2": 0.2298,
|
||||
"0.3": 0.1545
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.5101,
|
||||
"robustness": {
|
||||
"0.1": 0.454,
|
||||
"0.2": 0.2098,
|
||||
"0.3": 0.1324
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4427,
|
||||
"robustness": {
|
||||
"0.1": 0.4048,
|
||||
"0.2": 0.2461,
|
||||
"0.3": 0.1547
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6974,
|
||||
"robustness": {
|
||||
"0.1": 0.63,
|
||||
"0.2": 0.4452,
|
||||
"0.3": 0.312
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6674,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.4146,
|
||||
"0.3": 0.2764
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6691,
|
||||
"robustness": {
|
||||
"0.1": 0.6179,
|
||||
"0.2": 0.4691,
|
||||
"0.3": 0.3423
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6049,
|
||||
"robustness": {
|
||||
"0.1": 0.4685,
|
||||
"0.2": 0.3387,
|
||||
"0.3": 0.2378
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4654,
|
||||
"robustness": {
|
||||
"0.1": 0.4071,
|
||||
"0.2": 0.3073,
|
||||
"0.3": 0.2341
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5096,
|
||||
"robustness": {
|
||||
"0.1": 0.4624,
|
||||
"0.2": 0.3219,
|
||||
"0.3": 0.2159
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7058,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.3848,
|
||||
"0.3": 0.2276
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6861,
|
||||
"robustness": {
|
||||
"0.1": 0.6002,
|
||||
"0.2": 0.4184,
|
||||
"0.3": 0.2986
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6595,
|
||||
"robustness": {
|
||||
"0.1": 0.5775,
|
||||
"0.2": 0.4165,
|
||||
"0.3": 0.2899
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5573,
|
||||
"robustness": {
|
||||
"0.1": 0.4562,
|
||||
"0.2": 0.293,
|
||||
"0.3": 0.2167
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4835,
|
||||
"robustness": {
|
||||
"0.1": 0.4136,
|
||||
"0.2": 0.2221,
|
||||
"0.3": 0.1548
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5123,
|
||||
"robustness": {
|
||||
"0.1": 0.449,
|
||||
"0.2": 0.2571,
|
||||
"0.3": 0.1658
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
[
|
||||
{
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6843,
|
||||
"robustness": {
|
||||
"0.1": 0.618,
|
||||
"0.2": 0.4442,
|
||||
"0.3": 0.3226
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6807,
|
||||
"robustness": {
|
||||
"0.1": 0.5634,
|
||||
"0.2": 0.379,
|
||||
"0.3": 0.2741
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.663,
|
||||
"robustness": {
|
||||
"0.1": 0.5884,
|
||||
"0.2": 0.4499,
|
||||
"0.3": 0.3406
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5696,
|
||||
"robustness": {
|
||||
"0.1": 0.4816,
|
||||
"0.2": 0.3036,
|
||||
"0.3": 0.2133
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.5161,
|
||||
"robustness": {
|
||||
"0.1": 0.4067,
|
||||
"0.2": 0.2519,
|
||||
"0.3": 0.1753
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4831,
|
||||
"robustness": {
|
||||
"0.1": 0.4319,
|
||||
"0.2": 0.2668,
|
||||
"0.3": 0.1618
|
||||
}
|
||||
}
|
||||
]
|
||||
|
After Width: | Height: | Size: 29 KiB |
@@ -0,0 +1,19 @@
|
||||
seed,optimizer,augmentation,test_acc,robustness
|
||||
42,sgd,none,0.7013,"{'0.1': 0.6895, '0.2': 0.6449, '0.3': 0.5712}"
|
||||
42,sgd,standard,0.6791,"{'0.1': 0.6702, '0.2': 0.6315, '0.3': 0.5601}"
|
||||
42,sgd,aggressive,0.6703,"{'0.1': 0.6505, '0.2': 0.587, '0.3': 0.5143}"
|
||||
42,adam,none,0.5658,"{'0.1': 0.5575, '0.2': 0.5204, '0.3': 0.4498}"
|
||||
42,adam,standard,0.4394,"{'0.1': 0.4385, '0.2': 0.4069, '0.3': 0.3476}"
|
||||
42,adam,aggressive,0.45,"{'0.1': 0.4467, '0.2': 0.4168, '0.3': 0.3557}"
|
||||
123,sgd,none,0.7002,"{'0.1': 0.6902, '0.2': 0.6511, '0.3': 0.5781}"
|
||||
123,sgd,standard,0.6951,"{'0.1': 0.6833, '0.2': 0.6248, '0.3': 0.5406}"
|
||||
123,sgd,aggressive,0.6766,"{'0.1': 0.6661, '0.2': 0.6188, '0.3': 0.5369}"
|
||||
123,adam,none,0.4857,"{'0.1': 0.4851, '0.2': 0.4572, '0.3': 0.4191}"
|
||||
123,adam,standard,0.4536,"{'0.1': 0.4517, '0.2': 0.4216, '0.3': 0.3551}"
|
||||
123,adam,aggressive,0.4542,"{'0.1': 0.4568, '0.2': 0.4363, '0.3': 0.3909}"
|
||||
999,sgd,none,0.6961,"{'0.1': 0.6845, '0.2': 0.6509, '0.3': 0.5702}"
|
||||
999,sgd,standard,0.6896,"{'0.1': 0.6757, '0.2': 0.6251, '0.3': 0.5515}"
|
||||
999,sgd,aggressive,0.663,"{'0.1': 0.6542, '0.2': 0.6143, '0.3': 0.5326}"
|
||||
999,adam,none,0.5189,"{'0.1': 0.5138, '0.2': 0.4787, '0.3': 0.4115}"
|
||||
999,adam,standard,0.4934,"{'0.1': 0.4822, '0.2': 0.4312, '0.3': 0.34}"
|
||||
999,adam,aggressive,0.4039,"{'0.1': 0.4053, '0.2': 0.3814, '0.3': 0.3183}"
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0048419479751587,0.26248,1.6749397346496582,0.3807
|
||||
2,1.803226097946167,0.34018,1.5852585729598998,0.4277
|
||||
3,1.7538686734390259,0.3595,1.5596578207015992,0.44
|
||||
4,1.7227864822769166,0.36774,1.5332639671325683,0.4412
|
||||
5,1.707428610610962,0.37698,1.5266790023803711,0.446
|
||||
6,1.687287041053772,0.3853,1.5208749200820924,0.4534
|
||||
7,1.6827863134002685,0.38752,1.4709485807418823,0.4678
|
||||
8,1.6679418573379516,0.39552,1.4785521780014037,0.4569
|
||||
9,1.6778243729400635,0.38932,1.4824026586532593,0.4535
|
||||
10,1.661981136817932,0.39328,1.4527232948303224,0.4733
|
||||
11,1.6635226394271851,0.39552,1.4467358991622925,0.4761
|
||||
12,1.6477233642578124,0.39888,1.4165599590301514,0.4825
|
||||
13,1.64288627204895,0.4028,1.4934470342636108,0.4554
|
||||
14,1.64764131275177,0.40076,1.3920714778900147,0.4971
|
||||
15,1.6471917071151734,0.39942,1.4170124416351317,0.4897
|
||||
16,1.632711910057068,0.40306,1.437167017364502,0.4766
|
||||
17,1.6346079250717163,0.40256,1.4856373008728028,0.4602
|
||||
18,1.6376697371673583,0.40028,1.441266096687317,0.4795
|
||||
19,1.630727212524414,0.40478,1.405259503364563,0.4869
|
||||
20,1.6380860889434814,0.40302,1.500902590560913,0.4542
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9924071334075928,0.2727,1.6629206789016724,0.3942
|
||||
2,1.782769969177246,0.34778,1.6269580562591552,0.4111
|
||||
3,1.7246631295776367,0.36938,1.512624101638794,0.4465
|
||||
4,1.699350281906128,0.37616,1.561422179031372,0.4365
|
||||
5,1.6983959941864013,0.37842,1.5700433086395265,0.4367
|
||||
6,1.6805000713348388,0.38506,1.4903933980941773,0.4524
|
||||
7,1.6756603227996827,0.39022,1.589101205253601,0.4264
|
||||
8,1.6851841028594972,0.38372,1.53794698677063,0.4287
|
||||
9,1.667366398010254,0.39178,1.489094132232666,0.4576
|
||||
10,1.6613886919021605,0.39288,1.5648905296325684,0.4195
|
||||
11,1.6709651499176026,0.38908,1.4680940185546876,0.4564
|
||||
12,1.6781271075439452,0.38556,1.6323107133865356,0.3965
|
||||
13,1.6682115633773804,0.38786,1.67539608707428,0.394
|
||||
14,1.6545760455322265,0.39296,1.484992802810669,0.4598
|
||||
15,1.664471877670288,0.39126,1.4973457614898682,0.4527
|
||||
16,1.6631774948501588,0.39178,1.4977839450836181,0.4523
|
||||
17,1.6620827933502198,0.39418,1.4580686128616334,0.4648
|
||||
18,1.6693129856109619,0.3915,1.4670265132904052,0.4499
|
||||
19,1.6571670418548583,0.39328,1.5775521347045898,0.4209
|
||||
20,1.6555169297027588,0.39302,1.4923007045745849,0.45
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0259005422592162,0.26528,1.7835535663604736,0.3501
|
||||
2,1.867787606163025,0.313,1.7164635316848755,0.3679
|
||||
3,1.8262985235214233,0.33078,1.6985314193725587,0.3729
|
||||
4,1.8109176992416383,0.33478,1.6716279788970947,0.3847
|
||||
5,1.7937457250976563,0.34112,1.6525777326583861,0.3855
|
||||
6,1.7862395941925049,0.3438,1.6767313753128053,0.3771
|
||||
7,1.7756537421417236,0.34192,1.649860231399536,0.3899
|
||||
8,1.7808906607055663,0.34136,1.642415731048584,0.3915
|
||||
9,1.7794113501358033,0.33866,1.6533834308624267,0.3913
|
||||
10,1.7733372591781615,0.34298,1.679184874534607,0.3706
|
||||
11,1.7687813821411134,0.34438,1.6319498205184937,0.3919
|
||||
12,1.7672993996810913,0.34524,1.5924089086532593,0.4055
|
||||
13,1.7586094732284545,0.34774,1.5894474613189697,0.4183
|
||||
14,1.7682015967178344,0.34308,1.6354400550842285,0.3904
|
||||
15,1.7509896782684327,0.35068,1.6326687992095947,0.3796
|
||||
16,1.7584374965667724,0.34796,1.633549732208252,0.3948
|
||||
17,1.7580876070785523,0.35102,1.5947255432128906,0.4044
|
||||
18,1.750145943069458,0.35058,1.5891325847625732,0.418
|
||||
19,1.7573828769683837,0.3461,1.5960699188232421,0.4032
|
||||
20,1.7578185536956787,0.3469,1.5921159523010253,0.4039
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8579250992584229,0.32392,1.5558856401443482,0.4313
|
||||
2,1.6377606032562255,0.4044,1.5286170751571655,0.4396
|
||||
3,1.5881630603408814,0.41738,1.5185808382034303,0.4484
|
||||
4,1.5531442667388915,0.43602,1.4579217964172364,0.4586
|
||||
5,1.5474866276931762,0.43182,1.45462444896698,0.4727
|
||||
6,1.5150492757415772,0.44956,1.4247750133514405,0.475
|
||||
7,1.5141647922515868,0.44442,1.4379277662277221,0.4691
|
||||
8,1.5064603635406495,0.44778,1.4320051357269288,0.4814
|
||||
9,1.4982417543029785,0.45322,1.4486036586761475,0.4788
|
||||
10,1.495431781539917,0.45466,1.4063307929992677,0.4955
|
||||
11,1.477474406814575,0.46004,1.3848083011627197,0.5074
|
||||
12,1.478807368774414,0.46202,1.4433747045516967,0.4806
|
||||
13,1.481698868713379,0.45994,1.4207568691253663,0.4875
|
||||
14,1.4778631386184693,0.46108,1.3489213497161865,0.5107
|
||||
15,1.4712807468032838,0.46604,1.3919997245788573,0.4893
|
||||
16,1.483530824584961,0.46112,1.4437157917022705,0.4853
|
||||
17,1.4729563687133789,0.4616,1.3912005447387696,0.496
|
||||
18,1.4708208094787598,0.46484,1.4403230434417724,0.4778
|
||||
19,1.469524460105896,0.46858,1.3713474117279052,0.5112
|
||||
20,1.4658287371444703,0.46598,1.4386186960220337,0.4857
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.7651092416763305,0.34864,1.4801026742935182,0.4452
|
||||
2,1.4773385808563233,0.4608,1.4107468299865722,0.4928
|
||||
3,1.4245695472717286,0.48826,1.3724896883010864,0.5095
|
||||
4,1.3900219127273559,0.49736,1.2948982368469237,0.5366
|
||||
5,1.3694689385223389,0.50896,1.2803543266296387,0.5367
|
||||
6,1.3547902798843383,0.51456,1.2770936399459838,0.5403
|
||||
7,1.3447893646621705,0.51488,1.2862837057113647,0.5363
|
||||
8,1.3417145071411132,0.519,1.2915318742752075,0.5294
|
||||
9,1.338515231552124,0.52032,1.247458812904358,0.5552
|
||||
10,1.3313588520050048,0.52276,1.23068475151062,0.5559
|
||||
11,1.3147309407806396,0.52892,1.2065896244049072,0.5765
|
||||
12,1.3206113652801514,0.52802,1.2014692079544067,0.5784
|
||||
13,1.3099163648223877,0.52914,1.219766689491272,0.5604
|
||||
14,1.3167620611953736,0.52754,1.1866582777023316,0.5766
|
||||
15,1.3133597641754151,0.53076,1.2739620433807373,0.5413
|
||||
16,1.3288053150939942,0.5251,1.280159574699402,0.5474
|
||||
17,1.3213993626785279,0.5305,1.3356346948623656,0.5339
|
||||
18,1.3234788956451415,0.5274,1.2488118535995483,0.5621
|
||||
19,1.3162768488693237,0.53068,1.1890131019592285,0.5786
|
||||
20,1.3078575289154053,0.53372,1.2257861446380616,0.5658
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8653128273773194,0.32796,1.611533346939087,0.3995
|
||||
2,1.612140337791443,0.41292,1.623063315963745,0.4098
|
||||
3,1.5402255586242677,0.44038,1.4548933393478394,0.4697
|
||||
4,1.5275923446655273,0.44896,1.440786859703064,0.4802
|
||||
5,1.4979061844635009,0.45434,1.417161144065857,0.4818
|
||||
6,1.4915729135131837,0.46008,1.3997222789764405,0.4988
|
||||
7,1.4728927141571044,0.46468,1.387175965309143,0.4972
|
||||
8,1.4769946060943604,0.46384,1.466834031867981,0.4687
|
||||
9,1.4653232444000244,0.46896,1.4313162572860718,0.4816
|
||||
10,1.4634774406433106,0.47178,1.3978804470062256,0.4935
|
||||
11,1.464304472732544,0.47028,1.4218554107666015,0.4782
|
||||
12,1.4602648656463624,0.47172,1.4064278043746947,0.4873
|
||||
13,1.4479452781295776,0.47628,1.4901286466598511,0.4706
|
||||
14,1.4542893864059447,0.47666,1.3517237024307251,0.5077
|
||||
15,1.4511225904846192,0.4727,1.3578523971557617,0.5074
|
||||
16,1.4690848288726808,0.47104,1.3588831977844238,0.5053
|
||||
17,1.4494867483520508,0.47718,1.3755928480148316,0.5045
|
||||
18,1.460067735671997,0.4717,1.3654316268920899,0.5149
|
||||
19,1.4619887506103515,0.47406,1.3814095928192138,0.5031
|
||||
20,1.4534882698059082,0.47412,1.34085482711792,0.5189
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0278418547058106,0.24944,1.8780883220672608,0.3251
|
||||
2,1.8143817639541626,0.331,1.6300555017471314,0.4045
|
||||
3,1.7264616962432862,0.3622,1.597921866607666,0.4111
|
||||
4,1.6720347478866577,0.38138,1.608277184677124,0.41
|
||||
5,1.6579552675628662,0.3885,1.5234451871871948,0.435
|
||||
6,1.6298340225982666,0.39974,1.5039015756607055,0.4597
|
||||
7,1.6189628750610352,0.40322,1.5221681720733642,0.4422
|
||||
8,1.6198538191986085,0.40798,1.4737193891525269,0.4565
|
||||
9,1.6145962512969971,0.4075,1.466329651069641,0.4629
|
||||
10,1.621643498878479,0.40454,1.6072639640808106,0.4081
|
||||
11,1.6055096750259399,0.40826,1.527749220275879,0.4358
|
||||
12,1.6175734860229491,0.40692,1.4493362300872803,0.4722
|
||||
13,1.599906416091919,0.4124,1.4468452342987062,0.4715
|
||||
14,1.5933954864883424,0.41398,1.4764380693435668,0.4626
|
||||
15,1.5933787399291992,0.41372,1.533870811843872,0.4468
|
||||
16,1.5964966577911377,0.41302,1.4542985107421875,0.467
|
||||
17,1.6135705094146728,0.40712,1.4721424545288087,0.4527
|
||||
18,1.5964763764190675,0.41558,1.4527687072753905,0.4593
|
||||
19,1.591530082244873,0.41652,1.4388617012023925,0.4642
|
||||
20,1.589130754928589,0.41602,1.4646002044677735,0.4536
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0309293317031862,0.24762,1.7998561183929442,0.3421
|
||||
2,1.7957917514801025,0.33852,1.6648591684341432,0.3857
|
||||
3,1.7348672610855103,0.35496,1.654585442352295,0.4019
|
||||
4,1.7112164305877686,0.36548,1.5846896692276,0.4141
|
||||
5,1.6914103266525269,0.37268,1.5373584619522094,0.4395
|
||||
6,1.676686876564026,0.37698,1.637582375717163,0.3837
|
||||
7,1.6700877303695678,0.38384,1.5936841482162476,0.4161
|
||||
8,1.6569111339569091,0.38968,1.5129856206893921,0.4351
|
||||
9,1.6577720043182373,0.3908,1.499471364212036,0.4555
|
||||
10,1.6537026256942748,0.39124,1.537830891418457,0.4293
|
||||
11,1.6379817530059815,0.401,1.5753344255447388,0.4145
|
||||
12,1.64174579536438,0.39604,1.6890381210327148,0.3761
|
||||
13,1.6474020536422729,0.39648,1.5377606107711792,0.4443
|
||||
14,1.631554557762146,0.40042,1.5029741109848023,0.4553
|
||||
15,1.6246128929138184,0.40476,1.4533207773208618,0.4794
|
||||
16,1.6243007053375245,0.40454,1.4570706197738648,0.4646
|
||||
17,1.6160261626815795,0.40666,1.4699448907852173,0.4644
|
||||
18,1.6250457259750366,0.40384,1.43235032081604,0.4734
|
||||
19,1.6166238045501709,0.40848,1.6371612733840943,0.3933
|
||||
20,1.6181047994995117,0.40592,1.5512634952545166,0.4394
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9290429125213624,0.29754,1.6215127836227416,0.4071
|
||||
2,1.7239283940124512,0.36334,1.738866775894165,0.3631
|
||||
3,1.6748943075942992,0.38214,1.5513272310256958,0.4274
|
||||
4,1.6455059506988525,0.39562,1.508594535446167,0.4514
|
||||
5,1.624495430870056,0.40286,1.5517113952636719,0.4313
|
||||
6,1.6052836887359618,0.41002,1.50381274394989,0.4484
|
||||
7,1.6106374563980101,0.40804,1.5364446432113648,0.4302
|
||||
8,1.5953116985702516,0.41482,1.4611921619415282,0.4639
|
||||
9,1.5824666970062256,0.4168,1.4630669622421264,0.4662
|
||||
10,1.5835675071334838,0.4205,1.4603066038131713,0.4595
|
||||
11,1.5734461785125733,0.42412,1.4167298784255982,0.4862
|
||||
12,1.5816114051818848,0.42086,1.4909738641738892,0.4603
|
||||
13,1.5812189110946655,0.42152,1.4088649713516235,0.4889
|
||||
14,1.5847568724822998,0.42022,1.4178233968734741,0.4823
|
||||
15,1.576828946876526,0.42272,1.4326917642593384,0.4809
|
||||
16,1.576308758392334,0.42412,1.4116268701553345,0.4911
|
||||
17,1.5560429149246215,0.43062,1.4361749185562134,0.4727
|
||||
18,1.5523672427749633,0.43476,1.4626990962982178,0.4581
|
||||
19,1.5581282339859008,0.42906,1.4140942724227905,0.4822
|
||||
20,1.5584481320190429,0.42816,1.4160622207641602,0.4934
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0803506112289427,0.23,1.8052742357254028,0.3649
|
||||
2,1.795086921234131,0.35172,1.513138659286499,0.4536
|
||||
3,1.6374745911407471,0.4064,1.4209930673599243,0.4924
|
||||
4,1.5572372722625731,0.43516,1.3496723909378052,0.5132
|
||||
5,1.51327647895813,0.45306,1.2944782655715943,0.5399
|
||||
6,1.4705483396148682,0.47014,1.2843341438293456,0.5457
|
||||
7,1.419877719154358,0.48872,1.2024114828109742,0.5757
|
||||
8,1.3825100279998779,0.50334,1.1727403156280518,0.5862
|
||||
9,1.349350403137207,0.51694,1.1227089519500733,0.6043
|
||||
10,1.3265677324295044,0.5264,1.1120436277389527,0.6032
|
||||
11,1.2989790354156494,0.53608,1.09653370552063,0.6181
|
||||
12,1.271951921310425,0.5476,1.0683507353782653,0.6288
|
||||
13,1.252325573387146,0.55396,1.0399812242507935,0.6344
|
||||
14,1.2363645384979247,0.56054,1.005698392677307,0.6526
|
||||
15,1.2197590786361694,0.56746,0.9838214920043945,0.6531
|
||||
16,1.2122340144348145,0.5668,0.9973555490493774,0.6472
|
||||
17,1.1860296492004394,0.58086,0.9641439270019532,0.659
|
||||
18,1.1860070882797242,0.5799,0.9424298759460449,0.672
|
||||
19,1.1792696018218993,0.57868,0.9344734144210816,0.6784
|
||||
20,1.1561780276107787,0.5907,0.9342237223625183,0.6766
|
||||
|
@@ -0,0 +1,21 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0498613494491575,0.24344,1.827182295036316,0.3315
|
||||
2,1.7751250897216797,0.35796,1.5328051206588744,0.4525
|
||||
3,1.6295417392730713,0.40922,1.4538747985839844,0.4708
|
||||
4,1.556714581451416,0.4379,1.3427102214813234,0.5272
|
||||
5,1.5068591191864014,0.45632,1.3105488367080689,0.5268
|
||||
6,1.4472922039413452,0.48044,1.2710780916213988,0.5427
|
||||
7,1.416664746017456,0.49122,1.184156767177582,0.5828
|
||||
8,1.3932555706787109,0.50056,1.1800921089172363,0.5861
|
||||
9,1.354030286064148,0.51452,1.11541519241333,0.6093
|
||||
10,1.3268484888839722,0.52786,1.0645180792808533,0.6262
|
||||
11,1.2946206281661987,0.53744,1.0727122479438782,0.621
|
||||
12,1.2794264878082275,0.54414,1.044085719871521,0.6339
|
||||
13,1.2605588024902343,0.55172,1.0073520627975463,0.6379
|
||||
14,1.2451800212478639,0.55708,0.989365707397461,0.6558
|
||||
15,1.2263247829437256,0.56498,0.9746960914611816,0.6631
|
||||
16,1.2118944458770753,0.56814,0.9805354339599609,0.6508
|
||||
17,1.2089972533798217,0.57022,0.9678337673187256,0.6643
|
||||
18,1.1906279234695434,0.57802,0.9852053064346313,0.6507
|
||||
19,1.1747162343597413,0.58308,0.9369933985710144,0.6728
|
||||
20,1.160263595199585,0.58824,0.9429312159538269,0.6703
|
||||
|