mirror of
https://github.com/ION606/COGMOD-HWI.git
synced 2026-05-14 22:16:57 +00:00
num 5
This commit is contained in:
+2
-2
@@ -224,9 +224,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python (pystan_env)",
|
||||
"display_name": "pystan_env",
|
||||
"language": "python",
|
||||
"name": "pystan_env"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
||||
+391
-119
@@ -4,7 +4,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## HW3 Problem 3"
|
||||
"# Problem 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -24,136 +24,408 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## HW3 Problem 3\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"# Define the simulate_diffusion function\n",
|
||||
"def simulate_diffusion(v, a, beta, tau, dt=1e-3, scale=1.0, max_time=10.0):\n",
|
||||
" \"\"\"\n",
|
||||
" Simulates one realization of the diffusion process given a set of parameters.\n",
|
||||
"\n",
|
||||
" Parameters:\n",
|
||||
" -----------\n",
|
||||
" v : float\n",
|
||||
" The drift rate (rate of information uptake).\n",
|
||||
" a : float\n",
|
||||
" The boundary separation (decision threshold).\n",
|
||||
" beta : float in [0, 1]\n",
|
||||
" Relative starting point (prior option preferences).\n",
|
||||
" tau : float\n",
|
||||
" Non-decision time (additive constant).\n",
|
||||
" dt : float, optional (default: 1e-3 = 0.001)\n",
|
||||
" The step size for the Euler algorithm.\n",
|
||||
" scale : float, optional (default: 1.0)\n",
|
||||
" The scale (sqrt(var)) of the Wiener process.\n",
|
||||
" max_time: float, optional (default: 10.0)\n",
|
||||
" The maximum number of seconds before forced termination.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" --------\n",
|
||||
" (rt, decision) - a tuple of response time (rt - float) and a binary decision (decision - int).\n",
|
||||
" \"\"\"\n",
|
||||
" # Initialize the process\n",
|
||||
" y = beta * a # Starting point\n",
|
||||
" num_steps = tau # Initialize time with non-decision time\n",
|
||||
" const = scale * np.sqrt(dt) # Scale for the Wiener process\n",
|
||||
"\n",
|
||||
" # Simulate the diffusion process\n",
|
||||
" while (y <= a and y >= 0) and num_steps <= max_time:\n",
|
||||
" z = np.random.randn() # Random noise from a standard normal distribution\n",
|
||||
" y += v * dt + const * z # Update evidence accumulation\n",
|
||||
" num_steps += dt # Increment time\n",
|
||||
"\n",
|
||||
" # Determine the decision and response time\n",
|
||||
" if y >= a:\n",
|
||||
" decision = 1 # Upper boundary (correct decision)\n",
|
||||
" else:\n",
|
||||
" decision = 0 # Lower boundary (incorrect decision)\n",
|
||||
" rt = round(num_steps, 3) # Round RT to 3 decimal places\n",
|
||||
"\n",
|
||||
" return (rt, decision)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define the simulate_diffusion_n function\n",
|
||||
"def simulate_diffusion_n(num_sims, v, a, beta, tau, dt=1e-3, scale=1.0, max_time=10.0):\n",
|
||||
" \"\"\"\n",
|
||||
" Simulates multiple realizations of the diffusion process.\n",
|
||||
"\n",
|
||||
" Parameters:\n",
|
||||
" -----------\n",
|
||||
" num_sims : int\n",
|
||||
" Number of simulations to run.\n",
|
||||
" v, a, beta, tau, dt, scale, max_time : same as simulate_diffusion.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" --------\n",
|
||||
" data : numpy.ndarray of shape (num_sims, 2)\n",
|
||||
" A 2D array where each row contains:\n",
|
||||
" - The response time (RT) for the trial.\n",
|
||||
" - The decision (1 for upper boundary, 0 for lower boundary).\n",
|
||||
" \"\"\"\n",
|
||||
" data = np.zeros((num_sims, 2)) # Initialize array to store results\n",
|
||||
" for n in range(num_sims):\n",
|
||||
" data[n, :] = simulate_diffusion(v, a, beta, tau, dt, scale, max_time)\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define the visualize_diffusion_model function\n",
|
||||
"def visualize_diffusion_model(data, figsize=(8, 6)):\n",
|
||||
" \"\"\"\n",
|
||||
" Visualizes the RT distributions for correct and incorrect decisions.\n",
|
||||
"\n",
|
||||
" Parameters:\n",
|
||||
" -----------\n",
|
||||
" data : numpy.ndarray of shape (num_sims, 2)\n",
|
||||
" The output from simulate_diffusion_n.\n",
|
||||
" figsize : tuple, optional (default: (8, 6))\n",
|
||||
" Size of the figure.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" --------\n",
|
||||
" f : matplotlib.figure.Figure\n",
|
||||
" The figure object.\n",
|
||||
" \"\"\"\n",
|
||||
" f, ax = plt.subplots(1, 1, figsize=figsize)\n",
|
||||
" sns.histplot(data[data[:, 1] == 1, 0], color='maroon', alpha=0.7, ax=ax, label='Correct responses')\n",
|
||||
" sns.histplot(data[data[:, 1] == 0, 0], color='gray', ax=ax, label='Incorrect responses')\n",
|
||||
" sns.despine(ax=ax)\n",
|
||||
" ax.set_xlabel('Response time (s)', fontsize=18)\n",
|
||||
" ax.set_ylabel('Frequency', fontsize=18)\n",
|
||||
" ax.legend(fontsize=18)\n",
|
||||
" return f\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Part 1: Explore the effect of drift rate (v) on RT distributions\n",
|
||||
"def explore_drift_rate():\n",
|
||||
" # Baseline parameters\n",
|
||||
" a = 2.0\n",
|
||||
" beta = 0.5\n",
|
||||
" tau = 0.5\n",
|
||||
" scale = 1.0\n",
|
||||
" max_time = 10.0\n",
|
||||
" dt = 1e-3\n",
|
||||
" num_sims = 2000\n",
|
||||
"\n",
|
||||
" # Vary drift rate (v)\n",
|
||||
" v_values = np.linspace(0.5, 1.5, 25)\n",
|
||||
" mean_rt_upper = []\n",
|
||||
" mean_rt_lower = []\n",
|
||||
"\n",
|
||||
" for v in v_values:\n",
|
||||
" data = simulate_diffusion_n(num_sims, v, a, beta, tau, dt, scale, max_time)\n",
|
||||
" mean_rt_upper.append(data[data[:, 1] == 1, 0].mean())\n",
|
||||
" mean_rt_lower.append(data[data[:, 1] == 0, 0].mean())\n",
|
||||
"\n",
|
||||
" # Plot results\n",
|
||||
" plt.figure(figsize=(8, 6))\n",
|
||||
" plt.plot(v_values, mean_rt_upper, label='Upper Boundary (Correct)', color='maroon')\n",
|
||||
" plt.plot(v_values, mean_rt_lower, label='Lower Boundary (Incorrect)', color='gray')\n",
|
||||
" plt.xlabel('Drift Rate (v)', fontsize=14)\n",
|
||||
" plt.ylabel('Mean Response Time (s)', fontsize=14)\n",
|
||||
" plt.legend(fontsize=12)\n",
|
||||
" plt.title('Effect of Drift Rate on Mean RTs', fontsize=16)\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Part 2: Explore the effects of other parameters\n",
|
||||
"def explore_parameters():\n",
|
||||
" # Baseline parameters\n",
|
||||
" v = 1.0\n",
|
||||
" a = 2.0\n",
|
||||
" beta = 0.5\n",
|
||||
" tau = 0.5\n",
|
||||
" scale = 1.0\n",
|
||||
" max_time = 10.0\n",
|
||||
" dt = 1e-3\n",
|
||||
" num_sims = 2000\n",
|
||||
"\n",
|
||||
" # Vary boundary separation (a)\n",
|
||||
" a_values = np.linspace(1.0, 3.0, 25)\n",
|
||||
" mean_rt_upper = []\n",
|
||||
" mean_rt_lower = []\n",
|
||||
" std_rt_upper = []\n",
|
||||
" std_rt_lower = []\n",
|
||||
"\n",
|
||||
" for a in a_values:\n",
|
||||
" data = simulate_diffusion_n(num_sims, v, a, beta, tau, dt, scale, max_time)\n",
|
||||
" mean_rt_upper.append(data[data[:, 1] == 1, 0].mean())\n",
|
||||
" mean_rt_lower.append(data[data[:, 1] == 0, 0].mean())\n",
|
||||
" std_rt_upper.append(data[data[:, 1] == 1, 0].std())\n",
|
||||
" std_rt_lower.append(data[data[:, 1] == 0, 0].std())\n",
|
||||
"\n",
|
||||
" # Plot results\n",
|
||||
" plt.figure(figsize=(12, 6))\n",
|
||||
" plt.subplot(1, 2, 1)\n",
|
||||
" plt.plot(a_values, mean_rt_upper, label='Upper Boundary (Correct)', color='maroon')\n",
|
||||
" plt.plot(a_values, mean_rt_lower, label='Lower Boundary (Incorrect)', color='gray')\n",
|
||||
" plt.xlabel('Boundary Separation (a)', fontsize=14)\n",
|
||||
" plt.ylabel('Mean Response Time (s)', fontsize=14)\n",
|
||||
" plt.legend(fontsize=12)\n",
|
||||
"\n",
|
||||
" plt.subplot(1, 2, 2)\n",
|
||||
" plt.plot(a_values, std_rt_upper, label='Upper Boundary (Correct)', color='maroon')\n",
|
||||
" plt.plot(a_values, std_rt_lower, label='Lower Boundary (Incorrect)', color='gray')\n",
|
||||
" plt.xlabel('Boundary Separation (a)', fontsize=14)\n",
|
||||
" plt.ylabel('Standard Deviation of RT (s)', fontsize=14)\n",
|
||||
" plt.legend(fontsize=12)\n",
|
||||
"\n",
|
||||
" plt.suptitle('Effect of Boundary Separation on RT Distributions', fontsize=16)\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the exploration functions\n",
|
||||
"explore_drift_rate()\n",
|
||||
"explore_parameters()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Problem 5\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pystan\n",
|
||||
"import arviz as az\n",
|
||||
"# Simulate data\n",
|
||||
"N = 100\n",
|
||||
"alpha = 2.3\n",
|
||||
"beta = 4.0\n",
|
||||
"sigma = 2.0\n",
|
||||
"\n",
|
||||
"x = np.random.normal(size=N)\n",
|
||||
"y = alpha + beta * x + sigma * np.random.normal(size=N)\n",
|
||||
"\n",
|
||||
"stan_code = \"\"\"\n",
|
||||
"data {\n",
|
||||
" int<lower=1> N; // Number of observations\n",
|
||||
" vector[N] x; // Covariate\n",
|
||||
" vector[N] y; // Outcome\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"parameters {\n",
|
||||
" real alpha; // Intercept\n",
|
||||
" real beta; // Slope\n",
|
||||
" real<lower=0> sigma2; // Variance (sigma squared)\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"transformed parameters {\n",
|
||||
" real<lower=0> sigma; // Standard deviation\n",
|
||||
" sigma = sqrt(sigma2);\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"model {\n",
|
||||
" // Priors\n",
|
||||
" alpha ~ normal(0, 10);\n",
|
||||
" beta ~ normal(0, 10);\n",
|
||||
" sigma2 ~ inv_gamma(1, 1);\n",
|
||||
"\n",
|
||||
" // Likelihood\n",
|
||||
" y ~ normal(alpha + beta * x, sigma);\n",
|
||||
"}\n",
|
||||
"\"\"\"\n",
|
||||
"# Prepare data dictionary\n",
|
||||
"data_dict = {\n",
|
||||
" 'N': N,\n",
|
||||
" 'x': x,\n",
|
||||
" 'y': y\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Build and fit the model\n",
|
||||
"model = stan.build(stan_code, data=data_dict)\n",
|
||||
"fit = model.sample(num_chains=4, num_samples=1000, num_warmup=500)\n",
|
||||
"summary = az.summary(fit, var_names=['alpha', 'beta', 'sigma'])\n",
|
||||
"print(summary)\n",
|
||||
"\n",
|
||||
"# Trace plots for convergence diagnostics\n",
|
||||
"az.plot_trace(fit, var_names=['alpha', 'beta', 'sigma'], compact=False, legend=True)\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"# Plot posterior distributions\n",
|
||||
"az.plot_posterior(fit, var_names=['alpha', 'beta', 'sigma'])\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Simulate larger dataset\n",
|
||||
"N_large = 1000\n",
|
||||
"x_large = np.random.normal(size=N_large)\n",
|
||||
"y_large = alpha + beta * x_large + sigma * np.random.normal(size=N_large)\n",
|
||||
"\n",
|
||||
"# Prepare data dictionary for larger dataset\n",
|
||||
"data_dict_large = {\n",
|
||||
" 'N': N_large,\n",
|
||||
" 'x': x_large,\n",
|
||||
" 'y': y_large\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Fit the model with larger dataset\n",
|
||||
"fit_large = model.sample(data=data_dict_large, num_chains=4, num_samples=1000, num_warmup=500)\n",
|
||||
"\n",
|
||||
"# Summarize results\n",
|
||||
"summary_large = az.summary(fit_large, var_names=['alpha', 'beta', 'sigma'])\n",
|
||||
"print(summary_large)\n",
|
||||
"\n",
|
||||
"# Compare uncertainty\n",
|
||||
"print(\"Uncertainty (N=100):\")\n",
|
||||
"print(summary[['mean', 'sd', 'hdi_3%', 'hdi_97%']])\n",
|
||||
"print(\"\\nUncertainty (N=1000):\")\n",
|
||||
"print(summary_large[['mean', 'sd', 'hdi_3%', 'hdi_97%']])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's start with the simplest problem of all: Bayesian estimation of the mean and variance from a sample of data points. Our model is:\n",
|
||||
"$$\n",
|
||||
"\\begin{align}\n",
|
||||
" \\mu &\\sim \\mathcal{N}(0, 3)\\\\\n",
|
||||
" \\sigma^2 &\\sim \\text{Inv-Gamma}(1, 1)\\\\\n",
|
||||
" y_n &\\sim \\mathcal{N}(\\mu, \\sigma^2) \\quad \\text{for} \\,\\, n = 1,\\dots,N\n",
|
||||
"\\end{align}\n",
|
||||
"$$"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import seaborn as sns\n",
|
||||
"\n",
|
||||
"def simulate_diffusion(v, a, beta, tau, dt=1e-3, scale=1.0, max_time=10., rng=None):\n",
|
||||
" \"\"\"\n",
|
||||
" Simulates one realization of the diffusion process given\n",
|
||||
" a set of parameters and a step size `dt`.\n",
|
||||
"import stan\n",
|
||||
"\n",
|
||||
" Parameters:\n",
|
||||
" -----------\n",
|
||||
" v : float\n",
|
||||
" The drift rate (rate of information uptake)\n",
|
||||
" a : float\n",
|
||||
" The boundary separation (decision threshold).\n",
|
||||
" beta : float in [0, 1]\n",
|
||||
" Relative starting point (prior option preferences)\n",
|
||||
" tau : float\n",
|
||||
" Non-decision time (additive constant)\n",
|
||||
" dt : float, optional (default: 1e-3 = 0.001)\n",
|
||||
" The step size for the Euler algorithm.\n",
|
||||
" scale : float, optional (default: 1.0)\n",
|
||||
" The scale (sqrt(var)) of the Wiener process. Not considered\n",
|
||||
" a parameter and typically fixed to either 1.0 or 0.1.\n",
|
||||
" max_time : float, optional (default: .10)\n",
|
||||
" The maximum number of seconds before forced termination.\n",
|
||||
" rng : np.random.Generator or None, optional (default: None)\n",
|
||||
" A random number generator with locally set seed or None\n",
|
||||
" If None provided, a new generator will be spawned within the function.\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" --------\n",
|
||||
" (x, c) - a tuple of response time (y - float) and a \n",
|
||||
" binary decision (c - int) \n",
|
||||
" \"\"\"\n",
|
||||
"try:\n",
|
||||
" import arviz as az\n",
|
||||
"except ImportError as err:\n",
|
||||
" print(\"Please, install arviz for easy visualization of Stan models.\")\n",
|
||||
"\n",
|
||||
" # Inits (process starts at relative starting point)\n",
|
||||
" y = beta * a\n",
|
||||
" num_steps = tau\n",
|
||||
" const = scale * np.sqrt(dt)\n",
|
||||
" if rng is None:\n",
|
||||
" rng = np.random.default_rng()\n",
|
||||
"import nest_asyncio\n",
|
||||
"nest_asyncio.apply()\n",
|
||||
"### Simulate data\n",
|
||||
"mu = 2.5\n",
|
||||
"sigma = 3\n",
|
||||
"N = 50\n",
|
||||
"y = np.random.normal(mu, sigma, size=N)\n",
|
||||
"sns.histplot(y, color='gray')\n",
|
||||
"### Create data dictionary\n",
|
||||
"data_dict = {\n",
|
||||
" 'y': y,\n",
|
||||
" 'N': N\n",
|
||||
"}\n",
|
||||
"program_code = \"\"\"\n",
|
||||
"\n",
|
||||
" # Loop through process and check boundary conditions\n",
|
||||
" while (y <= a and y >= 0) and num_steps <= max_time:\n",
|
||||
" # Perform diffusion equation\n",
|
||||
" z = rng.normal()\n",
|
||||
" y += v * dt + const * z\n",
|
||||
"data {\n",
|
||||
" int<lower=1> N;\n",
|
||||
" vector[N] y;\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
" # Increment step counter\n",
|
||||
" num_steps += dt\n",
|
||||
"parameters {\n",
|
||||
" real mu;\n",
|
||||
" real<lower=0> sigma2;\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
" if y >= a:\n",
|
||||
" c = 1.\n",
|
||||
" else:\n",
|
||||
" c = 0.\n",
|
||||
" return (round(num_steps, 4), c)\n",
|
||||
"transformed parameters {\n",
|
||||
" real<lower=0> sigma;\n",
|
||||
" sigma = sqrt(sigma2);\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"def simulate_diffusion_n(num_sims, v, a, beta, tau, dt=1e-3, scale=1.0, max_time=10., rng=None):\n",
|
||||
" \"\"\"Simulate multiple realizations of the diffusion process.\"\"\"\n",
|
||||
" \n",
|
||||
" # Inits\n",
|
||||
" data = np.zeros((num_sims, 2))\n",
|
||||
" if rng is None:\n",
|
||||
" rng = np.random.default_rng()\n",
|
||||
" \n",
|
||||
" # Create data set\n",
|
||||
" for n in range(num_sims):\n",
|
||||
" data[n, :] = simulate_diffusion(v, a, beta, tau, dt, scale, max_time, rng)\n",
|
||||
" return data\n",
|
||||
"model {\n",
|
||||
" // Priors\n",
|
||||
" mu ~ normal(0, 3);\n",
|
||||
" sigma2 ~ inv_gamma(1, 1);\n",
|
||||
"\n",
|
||||
"def visualize_data(data, figsize=(10, 5)):\n",
|
||||
" \"\"\"Helper function to visualize a simple response time data set.\"\"\"\n",
|
||||
"\n",
|
||||
" f, axarr = plt.subplots(1, 2, figsize=figsize)\n",
|
||||
" \n",
|
||||
" # Histogram of response times\n",
|
||||
" sns.histplot(\n",
|
||||
" data[:, 0][data[:, 1] == 1], ax=axarr[0], color='#AA0000', alpha=0.8, lw=2, label=f'Response 1')\n",
|
||||
" sns.histplot(\n",
|
||||
" data[:, 0][data[:, 1] == 0], ax=axarr[0], color='#0000AA', alpha=0.8, lw=2, label=f'Response 0')\n",
|
||||
"\n",
|
||||
" # Barplot of categorical responses\n",
|
||||
" response, frequency = np.unique(data[:, 1], return_counts=True)\n",
|
||||
" sns.barplot(x=response.astype(np.int32), y=frequency, ax=axarr[1], alpha=0.8, color='#00AA00')\n",
|
||||
"\n",
|
||||
" # Labels and embelishments\n",
|
||||
" axarr[0].set_xlabel('Response time (s)', fontsize=16)\n",
|
||||
" axarr[0].legend(fontsize=16)\n",
|
||||
" axarr[0].set_ylabel('Count', fontsize=16)\n",
|
||||
" axarr[1].set_xlabel('Response', fontsize=16)\n",
|
||||
" axarr[1].set_ylabel('Frequency', fontsize=16)\n",
|
||||
" for ax in axarr:\n",
|
||||
" sns.despine(ax=ax)\n",
|
||||
" ax.grid(alpha=0.1, color='black')\n",
|
||||
"\n",
|
||||
" f.suptitle('Data Summary', fontsize=18)\n",
|
||||
"\n",
|
||||
" f.tight_layout()\n",
|
||||
"\n",
|
||||
"# Baseline parameters\n",
|
||||
"a = 2.0\n",
|
||||
"beta = 0.5\n",
|
||||
"tau = 0.5\n",
|
||||
"scale = 1.0\n",
|
||||
"max_time = 10.0\n",
|
||||
"dt = 1e-3\n",
|
||||
"num_sims = 2000\n",
|
||||
"\n",
|
||||
"# Vary drift rate\n",
|
||||
"v_values = np.linspace(0.5, 1.5, 25)\n",
|
||||
"mean_rt_upper = []\n",
|
||||
"mean_rt_lower = []\n",
|
||||
"\n",
|
||||
"for v in v_values:\n",
|
||||
" data = simulate_diffusion_n(num_sims, v, a, beta, tau, dt, scale, max_time)\n",
|
||||
" mean_rt_upper.append(data[data[:, 1] == 1, 0].mean())\n",
|
||||
" mean_rt_lower.append(data[data[:, 1] == 0, 0].mean())\n",
|
||||
"\n",
|
||||
"# Plot results\n",
|
||||
"plt.figure(figsize=(8, 6))\n",
|
||||
"plt.plot(v_values, mean_rt_upper, label='Upper Boundary (Correct)', color='maroon')\n",
|
||||
"plt.plot(v_values, mean_rt_lower, label='Lower Boundary (Incorrect)', color='gray')\n",
|
||||
"plt.xlabel('Drift Rate (v)')\n",
|
||||
"plt.ylabel('Mean Response Time (s)')\n",
|
||||
"plt.legend()\n",
|
||||
"plt.title('Effect of Drift Rate on Mean RTs')\n",
|
||||
"plt.show()"
|
||||
" // Data model (likelihood)\n",
|
||||
" for (n in 1:N) {\n",
|
||||
" y[n] ~ normal(mu, sigma);\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\"\"\"\n",
|
||||
"### Build model\n",
|
||||
"model = stan.build(program_code, data=data_dict)\n",
|
||||
"### Fit model\n",
|
||||
"fit = model.sample(num_chains=4, num_samples=1000, num_warmup=500)\n",
|
||||
"### Explore raw model outouts\n",
|
||||
"results_df = fit.to_frame()\n",
|
||||
"results_df.head()\n",
|
||||
"results_df.shape\n",
|
||||
"prior_samples = np.random.normal(0, 3, size=4000)\n",
|
||||
"post_samples = results_df.mu.values\n",
|
||||
"f, ax = plt.subplots(1, 1, figsize=(8, 4))\n",
|
||||
"# sns.histplot(prior_samples, ax=ax, color=\"gray\", alpha=0.8, label=\"Prior\")\n",
|
||||
"sns.histplot(post_samples, color='maroon', ax=ax, label=\"Posterior\")\n",
|
||||
"ax.axvline(np.mean(post_samples), color=\"maroon\", linestyle=\"dotted\", lw=4)\n",
|
||||
"ax.axvline(mu, color=\"black\", linestyle=\"dashed\")\n",
|
||||
"sns.despine(ax=ax)\n",
|
||||
"ax.legend()\n",
|
||||
"### Summarize model\n",
|
||||
"az.summary(fit)\n",
|
||||
"### Visual inspection and diagnostics\n",
|
||||
"az.plot_trace(fit, var_names=['mu', 'sigma'], compact=False, legend=True)\n",
|
||||
"plt.tight_layout()\n",
|
||||
"### Forest plots\n",
|
||||
"az.plot_forest(fit, var_names=['mu', 'sigma'], combined=True)\n",
|
||||
"### Plot KDEs\n",
|
||||
"ax = az.plot_kde(results_df.mu, \n",
|
||||
" results_df.sigma, hdi_probs=[0.393, 0.86, 0.99])\n",
|
||||
"ax.axvline(mu, color='black', linestyle='--')\n",
|
||||
"ax.axhline(sigma, color='black', linestyle='-.')\n",
|
||||
"ax.set_xlabel(r\"$\\mu$\")\n",
|
||||
"ax.set_ylabel(r\"$\\sigma$\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user