mirror of
https://github.com/ION606/COGMOD-HWI.git
synced 2026-05-14 22:16:57 +00:00
309 lines
507 KiB
Plaintext
309 lines
507 KiB
Plaintext
|
|
{
|
||
|
|
"cells": [
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [],
|
||
|
|
"source": [
|
||
|
|
"# DATA\n",
|
||
|
|
"\n",
|
||
|
|
"import pandas as pd\n",
|
||
|
|
"import stan\n",
|
||
|
|
"import numpy as np\n",
|
||
|
|
"import matplotlib.pyplot as plt\n",
|
||
|
|
"import seaborn as sns\n",
|
||
|
|
"import arviz as az\n",
|
||
|
|
"\n",
|
||
|
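|
"# PyStan 3 drives its own asyncio event loop, which conflicts with the loop Jupyter already runs; nest_asyncio lets them nest\n",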
|
||
|
|
"import nest_asyncio\n",
|
||
|
|
"nest_asyncio.apply()\n",
|
||
|
|
"\n",
|
||
|
|
"data = pd.read_csv(\"sample_response_times.csv\", sep=';')\n",
|
||
|
|
"\n",
|
||
|
|
"data.head()"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [],
|
||
|
|
"source": [
|
||
|
|
"stan_code = \"\"\"\n",
|
||
|
|
"data {\n",
|
||
|
|
" int<lower=1> N;\n",
|
||
|
|
" array[N] real<lower=0> y;\n",
|
||
|
|
" array[N] int<lower=1, upper=2> condition;\n",
|
||
|
|
" array[N] int<lower=0, upper=1> choice;\n",
|
||
|
|
"}\n",
|
||
|
|
"\n",
|
||
|
|
"// All parameters are constrained to be non-negative via lower=0 (beta and tau are also bounded above)\n",
|
||
|
|
"parameters {\n",
|
||
|
|
" real<lower=0> v_easy; // Drift rate for easy condition\n",
|
||
|
|
" real<lower=0> v_hard; // Drift rate for hard condition\n",
|
||
|
|
" real<lower=0> a; // Boundary separation\n",
|
||
|
|
" real<lower=0, upper=1> beta; // Starting point bias\n",
|
||
|
|
" real<lower=0, upper=min(y)> tau; // Non-decision time with upper bound\n",
|
||
|
|
"}\n",
|
||
|
|
"\n",
|
||
|
|
"model {\n",
|
||
|
|
" // Priors (updated)\n",
|
||
|
|
" v_easy ~ gamma(1, 2);\n",
|
||
|
|
" v_hard ~ gamma(1, 2);\n",
|
||
|
|
" a ~ gamma(2, 0.5);\n",
|
||
|
|
" beta ~ beta(2, 2);\n",
|
||
|
|
" tau ~ gamma(1, 10);\n",
|
||
|
|
"\n",
|
||
|
|
" // Likelihood (unchanged)\n",
|
||
|
|
" for (n in 1:N) {\n",
|
||
|
|
" if (condition[n] == 1) {\n",
|
||
|
|
" if (choice[n] == 1) {\n",
|
||
|
|
" y[n] ~ wiener(a, tau, beta, v_easy);\n",
|
||
|
|
" } else {\n",
|
||
|
|
" y[n] ~ wiener(a, tau, 1 - beta, -v_easy);\n",
|
||
|
|
" }\n",
|
||
|
|
" }\n",
|
||
|
|
" if (condition[n] == 2) {\n",
|
||
|
|
" if (choice[n] == 1) {\n",
|
||
|
|
" y[n] ~ wiener(a, tau, beta, v_hard);\n",
|
||
|
|
" } else {\n",
|
||
|
|
" y[n] ~ wiener(a, tau, 1 - beta, -v_hard);\n",
|
||
|
|
" }\n",
|
||
|
|
" }\n",
|
||
|
|
" }\n",
|
||
|
|
"}\n",
|
||
|
|
"\"\"\""
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [],
|
||
|
|
"source": [
|
||
|
|
"stan_data = {\n",
|
||
|
|
" \"N\": len(data),\n",
|
||
|
|
" \"y\": data[\"rt\"].astype(float).values,\n",
|
||
|
|
" \"condition\": data[\"condition\"].astype(int).values,\n",
|
||
|
|
" \"choice\": data[\"choice\"].astype(int).values\n",
|
||
|
|
"}\n",
|
||
|
|
"\n",
|
||
|
|
"model = stan.build(program_code=stan_code, data=stan_data, random_seed=42)  # fixed seed so the posterior draws are reproducible across runs\n",
|
||
|
|
"fit = model.sample(num_chains=4, num_samples=2000)\n",
|
||
|
|
"\n",
|
||
|
|
"# Diagnostics\n",
|
||
|
|
"print(fit) # PyStan 3 prints only parameter names/shapes and the draw count here; R-hat and ESS come from az.summary(fit) in the next cell"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7wAAANACAYAAAAIE2FVAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzsvXeYJWd15/+pdOvmzj05apQTEkiAbILBNiIY7PWuzRpjg1mwsXd/Xu86riNO4PUua7AxBtsYMBhwAgECAQIkoTwaaTQ5d843p8pVvz/equruCZoZpZFG7+d55pmZ7hvqVrrvOed7vkeJoihCIpFIJBKJRCKRSCSSiwz1Qm+ARCKRSCQSiUQikUgkzwQy4JVIJBKJRCKRSCQSyUWJDHglEolEIpFIJBKJRHJRIgNeiUQikUgkEolEIpFclMiAVyKRSCQSiUQikUgkFyUy4JVIJBKJRCKRSCQSyUWJDHglEolEIpFIJBKJRHJRIgNeiUQikUgkEolEIpFclMiAVyKRSCQSiUQikUgkFyUy4JVIJBKJRCKRSCQSyUWJDHglEolEIpFIJBKJRHJRIgNeieQs/NM//ROKonDw4MFTfnfLLbfw2te+9pxep9vt8uu//uts3bqVTCbD9u3b+cAHPkAUReljWq0W//W//leuvPJKCoUC69at481vfjOHDh1a9Vrz8/P87M/+LOvXr8c0TdatW8eb3vQmFhcXcRyHkZERfuVXfuWUbfjIRz6CrutMT0+f516QSCQSieTiRX7XSyQXLzLglUjOwo/+6I9SLBb5x3/8x1U/P3HiBA888AA//dM/fdbX8H2fW2+9lU984hP8yq/8Cl//+td55zvfye///u/zm7/5m+nj2u02vu/zvve9j6997Wt85CMfodfr8fKXv5z5+fn0cW9/+9t54IEH+PM//3O+9a1v8eEPf5gNGzbQ6/UwTZN3vvOdfPrTn8ZxnFXb8bGPfYw3vvGNbNy48SnuFYlEIpFILh7kd71EchETSSSSs/KzP/uz0ZYtW6IwDNOfve9974tyuVzUbDbP+vxPf/rTERDdd999q37+h3/4h1Emk4kqlcppn+f7ftTtdqNCoRB98IMfTH9eKBSiD33oQ2d8v+PHj0eKokSf+cxn0p/df//9ERDdfvvtZ91eiUQikUheaMjveonk4kRWeCWSc+Dtb387ExMTfO9730t/9tnPfpY3v/nNlMvlsz7/jjvu4JJLLuHmm2/G9/30z6233orrujz00EPpY//5n/+Zl770pfT396PrOoVCgW63y+HDh9PH3HTTTfz5n/85H/7wh9m3b98qqRTA9u3bed3rXsfHP/7x9Gcf//jH2bx5M7feeutT2RUSiUQikVyUyO96ieTiRAa8Esk58AM/8ANs2LCBz3zmMwDs3LmTI0eO8Pa3v/2cnr+4uMjx48cxDGPVn5tvvhmAarUKwFe+8hV+8id/kpe97GV8/vOf56GHHmLnzp2MjIxg23b6el/4whd485vfzJ/92Z9x7bXXsnHjRv74j/+YMAzTx7z3ve/lnnvu4fDhwzQaDb7whS/wX/7Lf0FV5WUvkUgkEsnJyO96ieTiRL/QGyCRPB9QVZW3ve1tfPzjH+cv//Iv+cxnPsPIyAive93rzun5Q0ND7Nixg8997nOn/f22bdsA+PznP89rXvMaPvShD6W/c12XWq226vGjo6N85CMf4SMf+QiHDx/mU5/6FL/7u7/LmjVrePe73w3Am970JjZv3szHP/5xtm7diud5vOtd73oyH18ikUgkkose+V0vkVycyIBXIjlH3v72t/O///f/5rbbbuMLX/gCb33rW9H1c7uEbr31Vr74xS/S19fHpZdeesbH9Xo9DMNY9bNPfepTBEFwxudcfvnl/Omf/il/8zd/w759+9Kfq6rKe97zHv7f//t/rFmzhje96U2sX7/+nLZXIpFIJJIXIvK7XiK5+FCikxsCJBLJGbnhhhtYWl
piZmaGhx56KJUpnQ3P8/jBH/xBjh8/zq/+6q9y7bXX4roux44d48tf/jJf/epXMU2Tj33sY7z3ve/lfe97H7fccgsPPvggH/nIR7Asi7e85S188pOfpNls8oM/+IO87W1v44orrsAwDL70pS/xV3/1V3z1q1/ljW98Y/q+CwsLbNq0Cc/z+NrXvsbrX//6Z2rXSCQSiURyUSC/6yWSiwtZ4ZVIzoO3v/3t/M//+T+57LLLzvkLEMAwDL7xjW/wgQ98gL/5m79hbGyMYrHIjh07eMMb3pBmet/97nczNTXFRz/6Ud7//vdz0003cfvtt/NjP/Zj6Wtls1luvPFG/vZv/5aJiQk0TePyyy/nc5/73KovQIA1a9bw/d///Zw4ceKcJVkSiUQikbyQkd/1EsnFhazwSiQXMZVKhU2bNvE7v/M7/PZv//aF3hyJRCKRSCRPM/K7XiJ5YmSFVyK5CFlaWuLw4cN88IMfxDAMfv7nf/5Cb5JEIpFIJJKnEfldL5GcGzLglUieIkEQnDIbbyWqqj7r4wFuv/123vnOd7J161Y+85nPMDw8/Ky+v0QikUgkFxPyu14ief4iJc0SyVPk1a9+NXffffcZf//7v//7/MEf/MGzt0ESiUQikUieVuR3vUTy/EUGvBLJU+Tw4cO02+0z/n79+vVyRIBEIpFIJM9j5He9RPL8RQa8EolEIpFIJBKJRCK5KHlKPbxhGDI7O0upVEJRlKdrmyQSiUQiedJEUUS73Wb9+vXPek/dxYj8rpdIJBLJc43z+a5/SgHv7OwsmzZteiovIZFIJBLJM8LU1BQbN2680JvxvEd+10skEonkucq5fNc/pYC3VCqlb1Qul5/KS0kkEolE8rTQarXYtGlT+h0leWrI73qJRCKRPNc4n+/6pxTwJtKmcrksvwQlEolE8pxCym+fHuR3vUQikUieq5zLd71sbpJIJBKJRCKRSCQSyUWJDHglkmeRpuXx8FiNqVoPgJbt8dmHJtL/SyQSiUQieeGx1HawveBCb4ZEclEiA16J5Bmm0nH4+D3Hectf3csNf/hNfuJjD/CVPbMATNV6/PYX93FoXsz2u+9YhZ/62wf5xL1jTNdlECyRSCQSyQuB+49XePBE9UJvhkRyUfKUenglEsmZWWzb/NV3jvG5hyfxgogbN/fzy6+9jBu39HPFWtEHd8XaMg//9mspmQYAigK1rssffvUAf/jVA7xs+yA/8ZJNvOHadWQN7UJ+HIlEIpFIJM8Qr7h0hKwh61ASyTOBEkVR9GSf3Gq16Ovro9lsSiMLiSQmiiI++9Ak7//aQRw/5K03b+Kd37eNS0aKZ3yOF4TMNizato/lBSw0bXZPNbjz4ALj1R7DRZP/8optvO2lmylljWfx00gkzz/kd9PTi9yfEolEInmucT7fTbLCK5E8zfz9vWP88e0HeeVlI/zRW65my1Bh1e97rs++mRaPTzXYPd3g0FyLiWoPPzw196QAmwdzgMIHvn6Ij919nN9901X8hxvlbFGJRCKRSC4Wbts9w/bhItdu7LvQmyKRnDOPTzVw/JCbtw1e6E15QmTAK5E8TURRhKIo/PTLtrC2L8sbr12Hoii0bI+dYzUePFHlwRM19s82SWLbLUN5rlpX5vXXrGPbcIH+vEHO0IgQvb/jlR47x2tpX08YwYHZFv/hxgv3OSUSiUTy/MULQlqWx1DRvNCbIjmJqXpPBrySC8rh+TaqApeuObc59nNNG8d/7putyYBXInka2Dle449vP8jf/+xLGCpkuHS0xEfvPs53Di7y6GSdMAJTV3nxlgH+22su5UWb+7l+Yz+Dhcw5vX6j5/Jvj87w8buP83f3jjHbtLhqXZmFlsPv/chVGJrs+5FIJBLJ2Xl0os58y+YN166T3x3PIX7gilEy8nhILjCH5lvAuQe8fTkDL3jue8zIgFcieRpQFQXHC/ir7xzjzoMLTNctAK7f1M//99pL2TFS5JGJGj/9si3sGC1xbL
HD+792kHe/cjuXrSlhewFRBLnM6W8a/fkM7/r+bbz9ZVv4+3vH+NC3j3DngQWuXFdGV88+cFsikUgkEoCu6wPg+KE
|
||
|
|
"text/plain": [
|
||
|
|
"<Figure size 1200x1000 with 10 Axes>"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"metadata": {},
|
||
|
|
"output_type": "display_data"
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"name": "stdout",
|
||
|
|
"output_type": "stream",
|
||
|
|
"text": [
|
||
|
|
"v_easy: mean = 2.47, 95% HDI = [1.96736185 2.98469113]\n",
|
||
|
|
"v_hard: mean = 0.28, 95% HDI = [0.01384723 0.66034203]\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqoAAAIQCAYAAABJ1Ex5AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAUDhJREFUeJzt3XtUVPX+//HXgDKQcUnk4gXvylErvB0V66SWRWqW55R67JxETauTWorZL7t4yQrLTPsVaWZqpyI1T1odzSQvuUqrrxdWZWUZNpIJIV5GkAYG9u+Pfs63kYsMM8AGno+19qr58Nmf/Z7Z4/Biz96fbTEMwxAAAABgMn61XQAAAABQFoIqAAAATImgCgAAAFMiqAIAAMCUCKoAAAAwJYIqAAAATImgCgAAAFMiqAIAAMCUCKoAAAAwJYIqgGrVtm1bjRs3rrbLqPcWLlyo9u3by9/fX927d6/tcgDAJwiqACpt9erVslgs2rt3b5k/HzhwoC6//HKvt7N582bNnTvX63Eaiq1bt+rBBx/UVVddpVWrVumpp56q7ZIAwCca1XYBAOq3Q4cOyc/Ps7+JN2/erJSUFMJqJW3fvl1+fn569dVXFRAQUNvlAIDPcEQVQLWyWq1q3LhxbZfhkfz8/NouwSO//vqrgoKCCKkA6h2CKoBqdeE5qkVFRZo3b546deqkwMBAhYeH6+qrr1ZaWpokady4cUpJSZEkWSwW13Jefn6+ZsyYoZiYGFmtVsXGxurZZ5+VYRhu2y0oKNB9992nZs2aKTg4WDfffLOOHTsmi8XidqR27ty5slgs+uabb3T77bfrsssu09VXXy1J+vLLLzVu3Di1b99egYGBio6O1oQJE5Sbm+u2rfNjfP/99/rnP/+p0NBQRURE6LHHHpNhGMrMzNQtt9yikJAQRUdHa9GiRZV67ZxOp+bPn68OHTrIarWqbdu2evjhh+VwOFx9LBaLVq1apfz8fNdrtXr16jLHmzJlii699FKdO3eu1M/GjBmj6OhoFRcXV6o2STp27JgmTJigqKgoWa1WdevWTStXrnTrU1hYqNmzZ6tXr14KDQ1VkyZN9Je//EU7duwoNd6aNWvUq1cvBQcHKyQkRFdccYWef/55SVJGRoYsFosWL15car3du3fLYrHorbfeqnTtAOoGvvoH4LEzZ87oxIkTpdqLioouuu7cuXOVnJysiRMnqk+fPrLb7dq7d6/279+v66+/Xnfffbd++eUXpaWl6fXXX3db1zAM3XzzzdqxY4fuvPNOde/eXR9++KFmzpypY8eOuYWYcePGad26dbrjjjvUr18/ffzxxxo2bFi5dY0cOVKdOnXSU0895Qq9aWlpysjI0Pjx4xUdHa2DBw9q+fLlOnjwoD777DO3AC1Jo0ePVpcuXbRgwQJt2rRJTzzxhJo2baqXX35Z1157rZ5++mm9+eabeuCBB/TnP/9Z11xzTYWv1cSJE/Xaa6/ptttu04wZM/T5558rOTlZ3377rTZs2CBJev3117V8+XJ98cUXWrFihSSpf//+ZY43evRopaSkaNOmTRo5cqSr/dy5c3r//fc1btw4+fv7V1jTednZ2erXr58sFoumTJmiiIgIffDBB7rzzjtlt9s1bdo0SZLdbteKFSs0ZswYTZo0SWfPntWrr76qhIQEffHFF64Lv9LS0jRmzBhdd911evrppyVJ3377rT799FPdf//9at++va666iq9+eabmj59ulstb775poKDg3XLLbdUqnYAdYgBAJW0atUqQ1KFS7du3dzWadOmjZGYmOh6HBcXZwwbNqzC7UyePNko6+Np48aNhiTjiSeecGu/7bbbDIvFYhw+fNgwDMPYt2+fIcmYNm2aW79x48YZkow5c+a42ubMmWNIMsaMGVNqe+fOnSvV9tZbbxmSjF27dpUa46677nK1OZ1Oo1WrVobFYjEWLFjgaj916pQRFBTk9pqUJT093ZBkTJw40a39gQ
ceMCQZ27dvd7UlJiYaTZo0qXA8wzCMkpISo2XLlsatt97q1r5u3bpSz+li7rzzTqN58+bGiRMn3Nr//ve/G6Ghoa7Xzul0Gg6Hw63PqVOnjKioKGPChAmutvvvv98ICQkxnE5nudt8+eWXDUnGt99+62orLCw0mjVrdtHXE0DdxFf/ADyWkpKitLS0UsuVV1550XXDwsJ08OBB/fDDDx5vd/PmzfL399d9993n1j5jxgwZhqEPPvhAkrRlyxZJ0r333uvWb+rUqeWOfc8995RqCwoKcv3/b7/9phMnTqhfv36SpP3795fqP3HiRNf/+/v7q3fv3jIMQ3feeaerPSwsTLGxscrIyCi3Fun35ypJSUlJbu0zZsyQJG3atKnC9ctisVg0cuRIbd68WXl5ea72tWvXqmXLlq5THi7GMAz95z//0fDhw2UYhk6cOOFaEhISdObMGdfr4+/v7zp3tqSkRCdPnpTT6VTv3r3dXsOwsDDl5+e7TgEpy6hRoxQYGKg333zT1fbhhx/qxIkT+uc//+nRawGgbiCoAvBYnz59NHjw4FLLZZdddtF1H3/8cZ0+fVqdO3fWFVdcoZkzZ+rLL7+s1HZtNptatGih4OBgt/YuXbq4fn7+v35+fmrXrp1bv44dO5Y79oV9JenkyZO6//77FRUVpaCgIEVERLj6nTlzplT/1q1buz0ODQ1VYGCgmjVrVqr91KlT5dbyx+dwYc3R0dEKCwtzPVdPjR49WgUFBXrvvfckSXl5edq8ebNGjhxZ6lSG8uTk5Oj06dNavny5IiIi3Jbx48dL+v0Cr/Nee+01XXnlla5zkiMiIrRp0ya31/Dee+9V586dNWTIELVq1UoTJkxw/cFxXlhYmIYPH67U1FRX25tvvqmWLVvq2muvrdLrAcDcCKoAatQ111yjH3/8UStXrtTll1+uFStWqGfPnq7zK2vLH4+enjdq1Ci98soruueee/TOO+9o69atrvBUUlJSqn9Z53eWd86nccHFX+WpbHisrH79+qlt27Zat26dJOn9999XQUGBRo8eXekxzj/3f/7zn2UeWU9LS9NVV10lSXrjjTc0btw4dejQQa+++qq2bNmitLQ0XXvttW6vYWRkpNLT0/Xee++5zkMeMmSIEhMT3bY9duxYZWRkaPfu3Tp79qzee+89jRkzxuMp0ADUDVxMBaDGNW3aVOPHj9f48eOVl5ena665RnPnznV9dV5eOGvTpo0++ugjnT171u2o6nfffef6+fn/lpSU6MiRI+rUqZOr3+HDhytd46lTp7Rt2zbNmzdPs2fPdrVX5ZSFqjj/HH744QfXEWPp94uYTp8+7XquVTFq1Cg9//zzstvtWrt2rdq2bes6paEyIiIiFBwcrOLiYg0ePLjCvuvXr1f79u31zjvvuO3XOXPmlOobEBCg4cOHa/jw4SopKdG9996rl19+WY899pjryPKNN96oiIgIvfnmm+rbt6/OnTunO+64o9K1A6hb+BMUQI26cGqnSy+9VB07dnSbcqlJkyaSpNOnT7v1HTp0qIqLi/Xiiy+6tS9evFgWi0VDhgyRJCUkJEiSXnrpJbd+L7zwQqXrPH8k9MIjn0uWLKn0GN4YOnRomdt77rnnJKnCGQwuZvTo0XI4HHrttde0ZcsWjRo1yqP1/f39deutt+o///mPvv7661I/z8nJcesrub+On3/+ufbs2eO2zoXvCz8/P9c5z398bzRq1EhjxozRunXrtHr1al1xxRWVOjcaQN3EEVUANapr164aOHCgevXqpaZNm2rv3r1av369pkyZ4urTq1cvSdJ9992nhIQE+fv76+9//7uGDx+uQYMG6ZFHHtFPP/2kuLg4bd26Ve+++66mTZumDh06uNa/9dZbtWTJEuXm5rqmp/r+++8lVe7r9JCQEF1zzTV65plnVFRUpJYtW2rr1q06cuRINbwqpcXFxSkxMVHLly/X6dOnNWDAAH3xxRd67bXXNGLECA0aNKjKY/
fs2VMdO3bUI488IofD4dHX/uctWLBAO3bsUN++fTVp0iR17dpVJ0+e1P79+/XRRx/p5MmTkqSbbrpJ77zzjv76179
|
||
|
|
"text/plain": [
|
||
|
|
"<Figure size 800x600 with 2 Axes>"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"metadata": {},
|
||
|
|
"output_type": "display_data"
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"text/html": [
|
||
|
|
"<div>\n",
|
||
|
|
"<style scoped>\n",
|
||
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
||
|
|
" vertical-align: middle;\n",
|
||
|
|
" }\n",
|
||
|
|
"\n",
|
||
|
|
" .dataframe tbody tr th {\n",
|
||
|
|
" vertical-align: top;\n",
|
||
|
|
" }\n",
|
||
|
|
"\n",
|
||
|
|
" .dataframe thead th {\n",
|
||
|
|
" text-align: right;\n",
|
||
|
|
" }\n",
|
||
|
|
"</style>\n",
|
||
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||
|
|
" <thead>\n",
|
||
|
|
" <tr style=\"text-align: right;\">\n",
|
||
|
|
" <th></th>\n",
|
||
|
|
" <th>mean</th>\n",
|
||
|
|
" <th>sd</th>\n",
|
||
|
|
" <th>hdi_3%</th>\n",
|
||
|
|
" <th>hdi_97%</th>\n",
|
||
|
|
" <th>mcse_mean</th>\n",
|
||
|
|
" <th>mcse_sd</th>\n",
|
||
|
|
" <th>ess_bulk</th>\n",
|
||
|
|
" <th>ess_tail</th>\n",
|
||
|
|
" <th>r_hat</th>\n",
|
||
|
|
" </tr>\n",
|
||
|
|
" </thead>\n",
|
||
|
|
" <tbody>\n",
|
||
|
|
" <tr>\n",
|
||
|
|
" <th>v_easy</th>\n",
|
||
|
|
" <td>2.473</td>\n",
|
||
|
|
" <td>0.259</td>\n",
|
||
|
|
" <td>1.986</td>\n",
|
||
|
|
" <td>2.956</td>\n",
|
||
|
|
" <td>0.004</td>\n",
|
||
|
|
" <td>0.003</td>\n",
|
||
|
|
" <td>4368.0</td>\n",
|
||
|
|
" <td>4627.0</td>\n",
|
||
|
|
" <td>1.0</td>\n",
|
||
|
|
" </tr>\n",
|
||
|
|
" <tr>\n",
|
||
|
|
" <th>v_hard</th>\n",
|
||
|
|
" <td>0.281</td>\n",
|
||
|
|
" <td>0.176</td>\n",
|
||
|
|
" <td>0.000</td>\n",
|
||
|
|
" <td>0.580</td>\n",
|
||
|
|
" <td>0.003</td>\n",
|
||
|
|
" <td>0.002</td>\n",
|
||
|
|
" <td>3189.0</td>\n",
|
||
|
|
" <td>2151.0</td>\n",
|
||
|
|
" <td>1.0</td>\n",
|
||
|
|
" </tr>\n",
|
||
|
|
" <tr>\n",
|
||
|
|
" <th>a</th>\n",
|
||
|
|
" <td>0.846</td>\n",
|
||
|
|
" <td>0.026</td>\n",
|
||
|
|
" <td>0.798</td>\n",
|
||
|
|
" <td>0.896</td>\n",
|
||
|
|
" <td>0.000</td>\n",
|
||
|
|
" <td>0.000</td>\n",
|
||
|
|
" <td>4027.0</td>\n",
|
||
|
|
" <td>5078.0</td>\n",
|
||
|
|
" <td>1.0</td>\n",
|
||
|
|
" </tr>\n",
|
||
|
|
" <tr>\n",
|
||
|
|
" <th>beta</th>\n",
|
||
|
|
" <td>0.525</td>\n",
|
||
|
|
" <td>0.020</td>\n",
|
||
|
|
" <td>0.485</td>\n",
|
||
|
|
" <td>0.562</td>\n",
|
||
|
|
" <td>0.000</td>\n",
|
||
|
|
" <td>0.000</td>\n",
|
||
|
|
" <td>3813.0</td>\n",
|
||
|
|
" <td>4286.0</td>\n",
|
||
|
|
" <td>1.0</td>\n",
|
||
|
|
" </tr>\n",
|
||
|
|
" <tr>\n",
|
||
|
|
" <th>tau</th>\n",
|
||
|
|
" <td>0.397</td>\n",
|
||
|
|
" <td>0.003</td>\n",
|
||
|
|
" <td>0.392</td>\n",
|
||
|
|
" <td>0.402</td>\n",
|
||
|
|
" <td>0.000</td>\n",
|
||
|
|
" <td>0.000</td>\n",
|
||
|
|
" <td>3828.0</td>\n",
|
||
|
|
" <td>4132.0</td>\n",
|
||
|
|
" <td>1.0</td>\n",
|
||
|
|
" </tr>\n",
|
||
|
|
" </tbody>\n",
|
||
|
|
"</table>\n",
|
||
|
|
"</div>"
|
||
|
|
],
|
||
|
|
"text/plain": [
|
||
|
|
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail \\\n",
|
||
|
|
"v_easy 2.473 0.259 1.986 2.956 0.004 0.003 4368.0 4627.0 \n",
|
||
|
|
"v_hard 0.281 0.176 0.000 0.580 0.003 0.002 3189.0 2151.0 \n",
|
||
|
|
"a 0.846 0.026 0.798 0.896 0.000 0.000 4027.0 5078.0 \n",
|
||
|
|
"beta 0.525 0.020 0.485 0.562 0.000 0.000 3813.0 4286.0 \n",
|
||
|
|
"tau 0.397 0.003 0.392 0.402 0.000 0.000 3828.0 4132.0 \n",
|
||
|
|
"\n",
|
||
|
|
" r_hat \n",
|
||
|
|
"v_easy 1.0 \n",
|
||
|
|
"v_hard 1.0 \n",
|
||
|
|
"a 1.0 \n",
|
||
|
|
"beta 1.0 \n",
|
||
|
|
"tau 1.0 "
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"execution_count": 15,
|
||
|
|
"metadata": {},
|
||
|
|
"output_type": "execute_result"
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"# print and graph in a different cell because the stan takes 5 to 10 business centuries to run\n",
|
||
|
|
"\n",
|
||
|
|
"v_easy = fit[\"v_easy\"]\n",
|
||
|
|
"v_hard = fit[\"v_hard\"]\n",
|
||
|
|
"a = fit[\"a\"]\n",
|
||
|
|
"beta = fit[\"beta\"]\n",
|
||
|
|
"tau = fit[\"tau\"]\n",
|
||
|
|
"\n",
|
||
|
|
"fig = az.plot_trace(fit)\n",
|
||
|
|
"plt.subplots_adjust(hspace=0.5)\n",
|
||
|
|
"plt.savefig('trace_plot.png')\n",
|
||
|
|
"plt.show()\n",
|
||
|
|
"\n",
|
||
|
|
"print(  # az.hdi gives a true highest-density interval; np.percentile would give an equal-tailed interval instead\n",
|
||
|
|
" f\"v_easy: mean = {np.mean(v_easy):.2f}, 95% HDI = {az.hdi(v_easy.flatten(), hdi_prob=0.95)}\")\n",
|
||
|
|
"print(\n",
|
||
|
|
" f\"v_hard: mean = {np.mean(v_hard):.2f}, 95% HDI = {az.hdi(v_hard.flatten(), hdi_prob=0.95)}\")\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"fig, axes = plt.subplots(2, 1, figsize=(8, 6)) # 2 subplots\n",
|
||
|
|
"axes[0].hist(v_easy.flatten(), bins=30, density=True,\n",
|
||
|
|
" alpha=0.7, color='blue', edgecolor='black')\n",
|
||
|
|
"axes[0].set_title(\"Histogram of v_easy\")\n",
|
||
|
|
"\n",
|
||
|
|
"axes[1].hist(v_hard.flatten(), bins=30, density=True,\n",
|
||
|
|
" alpha=0.7, color='red', edgecolor='black')\n",
|
||
|
|
"axes[1].set_title(\"Histogram of v_hard\")\n",
|
||
|
|
"\n",
|
||
|
|
"plt.subplots_adjust(hspace=0.5)\n",
|
||
|
|
"plt.savefig(\"drift_rate_histograms.png\", bbox_inches='tight')\n",
|
||
|
|
"plt.show()\n",
|
||
|
|
"\n",
|
||
|
|
"az.summary(fit)"
|
||
|
|
]
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"metadata": {
|
||
|
|
"kernelspec": {
|
||
|
|
"display_name": ".venv",
|
||
|
|
"language": "python",
|
||
|
|
"name": "python3"
|
||
|
|
},
|
||
|
|
"language_info": {
|
||
|
|
"codemirror_mode": {
|
||
|
|
"name": "ipython",
|
||
|
|
"version": 3
|
||
|
|
},
|
||
|
|
"file_extension": ".py",
|
||
|
|
"mimetype": "text/x-python",
|
||
|
|
"name": "python",
|
||
|
|
"nbconvert_exporter": "python",
|
||
|
|
"pygments_lexer": "ipython3",
|
||
|
|
"version": "3.12.9"
|
||
|
|
}
|
||
|
|
},
|
||
|
|
"nbformat": 4,
|
||
|
|
"nbformat_minor": 2
|
||
|
|
}
|