Files

220 lines
28 KiB
Plaintext
Raw Permalink Normal View History

2025-04-14 22:31:56 -04:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "2d616210",
"metadata": {},
"outputs": [],
"source": [
"# Imports\n",
"\n",
"import pandas as pd\n",
"import stan\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import arviz as az\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
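"# Jupyter already runs an asyncio event loop, which PyStan 3 cannot use directly;\n",
"# nest_asyncio patches asyncio so PyStan can run inside the notebook.\n",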
"import nest_asyncio\n",
"nest_asyncio.apply()\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c5cdedfd",
"metadata": {},
"outputs": [],
"source": [
"url = \"https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/insurance.csv\"\n",
"data = pd.read_csv(url)\n",
"\n",
"# Encode the binary smoker flag as 0/1 so it can be used as a numeric predictor.\n",
"data['smoker'] = data['smoker'].map({'yes': 1, 'no': 0})\n",
"# .map() silently yields NaN for any label missing from the dict; fail loudly instead.\n",
"assert data['smoker'].notna().all(), 'smoker column contains values other than yes/no'\n",
"\n",
"feature_cols = ['bmi', 'age', 'children', 'smoker']\n",
"X = data[feature_cols].values.astype(float)\n",
"y_raw = data['charges'].values\n",
"\n",
"# Split BEFORE computing any scaling statistics so the test rows cannot leak into them.\n",
"# Same test_size and random_state as before, so the row partition is unchanged.\n",
"X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(\n",
"    X, y_raw, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# Standardization statistics (mean=0, sd=1) come from the training rows only.\n",
"X_mean = X_train_raw.mean(axis=0)\n",
"X_std = X_train_raw.std(axis=0)\n",
"# Leave the 0/1 smoker indicator on its original scale (shift 0, scale 1).\n",
"smoker_idx = feature_cols.index('smoker')\n",
"X_mean[smoker_idx] = 0.0\n",
"X_std[smoker_idx] = 1.0\n",
"y_mean = y_train_raw.mean()\n",
"y_std = y_train_raw.std()\n",
"\n",
"X_train = (X_train_raw - X_mean) / X_std\n",
"X_test = (X_test_raw - X_mean) / X_std\n",
"y_train = (y_train_raw - y_mean) / y_std\n",
"y_test = (y_test_raw - y_mean) / y_std\n",
"\n",
"# Full-data versions on the same (training) scale, kept under their original names.\n",
"X_standardized = (X - X_mean) / X_std\n",
"y = (y_raw - y_mean) / y_std\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1a832a9e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"Building: found in cache, done."
]
}
],
"source": [
"# Load the Stan program text (path is relative to the working directory).\n",
"with open('multiple_regression.stan', 'r') as stan_file:\n",
"    stan_code = stan_file.read()\n",
"\n",
"# Inputs for the model: training design matrix and targets, plus the held-out predictors\n",
"# (presumably used by the model to generate y_pred -- confirm in the .stan file).\n",
"n_train, n_features = X_train.shape\n",
"stan_data = {\n",
"    'N': n_train,\n",
"    'M': n_features,\n",
"    'X': X_train,\n",
"    'y': y_train,\n",
"    'N_test': len(X_test),\n",
"    'X_test': X_test\n",
"}\n",
"\n",
"# Build the model against this data with a fixed seed for reproducibility.\n",
"posterior = stan.build(stan_code, data=stan_data, random_seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "def536a9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling: 0%\n",
"Sampling: 25% (3000/12000)\n",
"Sampling: 50% (6000/12000)\n",
"Sampling: 75% (9000/12000)\n",
"Sampling: 100% (12000/12000)\n",
"Sampling: 100% (12000/12000), done.\n",
"Messages received during sampling:\n",
" Gradient evaluation took 0.00019 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 1.9 seconds.\n",
" Adjust your expectations accordingly!\n",
" Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:\n",
" Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/tmp/httpstan_gquumxii/model_rezrxovk.stan', line 21, column 2 to column 38)\n",
" If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,\n",
" but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.\n",
" Gradient evaluation took 0.000184 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 1.84 seconds.\n",
" Adjust your expectations accordingly!\n",
" Gradient evaluation took 0.000168 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 1.68 seconds.\n",
" Adjust your expectations accordingly!\n",
" Gradient evaluation took 0.000129 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 1.29 seconds.\n",
" Adjust your expectations accordingly!\n",
" Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:\n",
" Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/tmp/httpstan_gquumxii/model_rezrxovk.stan', line 21, column 2 to column 38)\n",
" If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,\n",
" but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n",
"alpha -0.396 0.017 -0.430 -0.364 0.0 0.0 9346.0 \n",
"beta[0] 0.165 0.016 0.136 0.195 0.0 0.0 10163.0 \n",
"beta[1] 0.298 0.016 0.269 0.328 0.0 0.0 10522.0 \n",
"beta[2] 0.042 0.015 0.014 0.072 0.0 0.0 10397.0 \n",
"beta[3] 1.951 0.039 1.875 2.020 0.0 0.0 9719.0 \n",
"sigma 0.507 0.011 0.486 0.527 0.0 0.0 9713.0 \n",
"\n",
" ess_tail r_hat \n",
"alpha 6429.0 1.0 \n",
"beta[0] 6037.0 1.0 \n",
"beta[1] 5893.0 1.0 \n",
"beta[2] 5565.0 1.0 \n",
"beta[3] 6633.0 1.0 \n",
"sigma 6303.0 1.0 \n"
]
},
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Posterior Distributions of Beta Coefficients')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGZCAYAAABxI8CQAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAPrhJREFUeJzt3XlclOX+//H3gAyrigbmUoK4ZKJlLmm550JmpYW5VKb2Fc0W86SlVh6wc4xj6dGyjpWmcigtl7RTqWgpaqappbZZoSm2uGC4Ysoy1+8Pf4yMLAIiAzev5+PBg5l77rnvz8w9y3uu+7qv22aMMQIAACjnPNxdAAAAQEkg1AAAAEsg1AAAAEsg1AAAAEsg1AAAAEsg1AAAAEsg1AAAAEsg1AAAAEsg1AAAAEsg1AAFmD9/vmw2m/bv3+/uUvI0ZMgQhYaGlsq6QkNDNWTIEOf17Odm+/btpbL+zp07q3PnzqWyritp27ZtuvXWW+Xv7y+bzaadO3e6uyRLyMzM1DPPPKNrr71WHh4e6tOnjyTp9OnTGjZsmGrWrCmbzabRo0dr//79stlsmj9/fpHWUdY/D0CoQQnLftNn//n4+KhRo0Z6/PHHdfjw4RJf35kzZxQTE6PExMQSX3Zpi4mJcXnu/Pz8VLduXd11112aN2+ezp07VyLr+eGHHxQTE1MmP5jLcm0lISMjQ/fdd59SU1M1ffp0xcfHKyQkJM95ExMTXV4PNptN1atXV9u2bfXuu+8Wu4YFCxZoxowZxb7/pSxbtkw9e/ZUUFCQ7Ha7ateurX79+mnt2rVXbJ2SNHfuXL388svq27ev4uLi9Le//U2S9OKLL2r+/PkaOXKk4uPjNWjQoCtax+Wy0meaWxigBM2bN89IMi+88IKJj483s2fPNoMHDzYeHh6mXr16Ji0trUTXl5KSYiSZ6OjoEl1utszMTPPXX38Zh8NxRZafU3R0tJFkZs2aZeLj482cOXPMpEmTzK233mokmRtuuMEcOHDA5T7p6enm7NmzRVrP4sWLjSSzbt26It3v7NmzJj093Xk9e1tv27atSMspbm3nzp0z586dK7F1ucPu3buNJDN79uxLzrtu3TojyYwaNcrEx8eb+Ph4M2PGDHPLLbcYSea1114rVg29evUyISEhxbpvQRwOhxkyZIiRZG666SYzefJk8/bbb5t//vOfpmXLlkaS2bRpU4mvN1v//v1NnTp1ck1v06aNadeuXa5a//rrL5OZmVmkdZTG58GV/kyzukpuSVKwvJ49e6pVq1aSpGHDhumqq67Sv//9b3344YcaOHCgm6u7tLS0NPn7+8vT01Oenp4lttwzZ87Iz8+vwHn69u2roKAg5/W///3vevfdd/XQQw/pvvvu05YtW5y3eXl5lVhteTHG6OzZs/L19ZW3t/cVXdel2O12t66/JBw5ckSSFBgYWOj7dOjQQX379nVeHzlypMLCwrRgwQI99thjJV1isU2bNk3z58/X6NGj9e9//1s2m81523PPPaf4+HhVqnTlvnKOHDmS5/N65MgRNWnSxGVadityUZX05wGuAHenKlhLfr/eP/74YyPJTJ482RhjTEZGhnnhhRdMWFiYsdvtJiQkxEyYMCFXq8O2bdtMjx49zFVXXWV8fHxMaGioGTp0qDHGmH379hlJuf5y/sLZvXu3iYyMNNWqVTPe3t6mZcuW5sMPP8yz5sTERDNy5EgTHBxsAgMDXW7bt2+fy31ef/1106RJE2O3202tWrXMo48+ao4dO+YyT6dOnUx4eLjZvn276dChg/H19TVPPvlkvs9ddktNSkpKnrcPHz7cSDKrV692Ths8eHCuX90LFy40LVq0MAEBAaZy5cqmadOmZsaMGS6P5+K/7JaRkJAQ06tXL7Nq1SrTsmVL4+3tbaZPn+68bfDgwbmet/Xr15vhw4eb6tWrm8qVK5tBgwaZ1NRUl5ou3i7Zci7zUrV16tTJdOrUyeX+hw
8fNg8//LCpUaOG8fb2NjfccIOZP3++yzzZr5OXX37ZvPnmm87XXKtWrczWrVtd5j148KAZMmSIqVOnjrHb7aZmzZrm7rvvzrX98/LZZ5+Z9u3bGz8/P1O1alVz9913mx9++MF5++DBg3M9tosfT07ZLTWLFy/OdVvTpk1Nx44dc02Pj483LVq0MD4+PqZatWqmf//+Lq17nTp1ylVD9uvn3LlzZuLEiaZFixamSpUqxs/Pz7Rv396sXbv2ko/9zJkzpnr16qZx48aFbv3Yu3ev6du3r6lWrZrx9fU1bdq0MR9//HGu+c6ePWv+/ve/m/r16xu73W6uueYa8/TTTzs/K/L7HMh+/i7+27dvn/M+8+bNc1nX7t27zX333WeCgoKMj4+PadSokXn22Wedt+f3ebBixQrntg8ICDB33HGH+e6771zmGTx4sPH39ze//fab6d27t/H39zdBQUFmzJgxzufsUp9pl/P6rChoqUGp2Lt3ryTpqquuknS+9SYuLk59+/bVmDFj9OWXXyo2Nla7d+/WsmXLJJ3/hdWjRw8FBwdr/PjxCgwM1P79+/XBBx9IkoKDgzVr1iyNHDlS99xzj+69915J0g033CBJ+v7779WuXTvVqVNH48ePl7+/vxYtWqQ+ffpo6dKluueee1xqfPTRRxUcHKy///3vSktLy/exxMTEaNKkSerWrZtGjhypn376SbNmzdK2bdu0adMml9aTP//8Uz179tSAAQP04IMP6uqrry72czho0CC99dZbWr16tbp3757nPGvWrNHAgQPVtWtXTZkyRZK0e/dubdq0SU8++aQ6duyoUaNG6dVXX9Wzzz6r66+/XpKc/yXpp59+0sCBAzVixAhFRUXpuuuuK7Cuxx9/XIGBgYqJiXE+F8nJyc4+IYVVmNpy+uuvv9S5c2ft2bNHjz/+uOrVq6fFixdryJAhOn78uJ588kmX+RcsWKBTp05pxIgRstlseumll3Tvvffql19+cW6zyMhIff/993riiScUGhqqI0eOaM2aNTpw4ECBHbI//fRT9ezZU2FhYYqJidFff/2lmTNnql27dvr6668VGhqqESNGqE6dOnrxxRc1atQotW7dulCvh1OnTuno0aOSpNTUVC1YsEDfffed3n77bZf5Jk+erIkTJ6pfv34aNmyYUlJSNHPmTHXs2FE7duxQYGCgnnvuOZ04cUK//fabpk+fLkkKCAiQJJ08eVJz5szRwIEDFRUVpVOnTuntt99WRESEtm7dqubNm+db4+eff67U1FSNHj26UC0Zhw8f1q233qozZ85o1KhRuuqqqxQXF6e7775bS5Yscb43HQ6H7r77bn3++ecaPny4rr/+en377beaPn26fv75Zy1fvlzBwcGKj4/X5MmTdfr0acXGxko6/7qJj4/X3/72N11zzTUaM2aMpPOfGykpKblq+uabb9ShQwd5eXlp+PDhCg0N1d69e/XRRx9p8uTJ+T6W+Ph4DR48WBEREZoyZYrOnDmjWbNmqX379tqxY4fL6yYrK0sRERFq06aNpk6dqk8//VTTpk1T/fr1NXLkyEt+phX39VmhuDtVwVqyf8l8+umnJiUlxfz666/mvffeM1dddZXx9fU1v/32m9m5c6eRZIYNG+Zy37FjxxpJzl+Gy5Ytu2SfjYL2P3ft2tU0a9bMpfXH4XCYW2+91TRs2DBXze3bt8/1K/PiX2ZHjhwxdrvd9OjRw2RlZTnne+2114wkM3fuXOe07F/Fb7zxxqWfOHPplppjx44ZSeaee+5xTru4pebJJ580VapUKfDXckH9VkJCQowks2rVqjxvy6ulpmXLli59bV566SUjyaVFLL9tdPEyC6rt4paaGTNmGEnmnXfecU5LT083t9xyiwkICDAnT540xlz49XvVVVe5tCB9+OGHRpL56KOPjDEXnt+XX34517ovpXnz5qZGjRrmzz//dE7btWuX8fDwMA899JBzWkGtLxfLr6XBw8PD2eKZbf
/+/cbT0zPX9G+//dZUqlTJZXp+fWoyMzNz9Vk6duyYufrqq83DDz9cYK2vvPKKkWSWLVt2ycdljDGjR482kszGjRu
"text/plain": [
"<Figure size 600x440 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"fit = posterior.sample(num_chains=4, num_samples=2000)\n",
"\n",
"# PyStan 3 returns draws with the parameter dimension first and the draws last:\n",
"# shape --> [N_test, num_chains * num_samples], i.e. (268, 8000) in the saved run.\n",
"y_pred_samples = fit['y_pred']\n",
"# Guard against silent mis-broadcasting if the layout ever differs.\n",
"assert y_pred_samples.shape[0] == y_test.shape[0], 'y_pred rows must line up with y_test'\n",
"\n",
"# Broadcast y_test across the draw axis and average over test points (axis=0):\n",
"# this gives one RMSE value per posterior draw.\n",
"rmse_samples = np.sqrt(np.mean((y_pred_samples - y_test[:, None])**2, axis=0))\n",
"rmse_mean = rmse_samples.mean()\n",
"rmse_ci = np.percentile(rmse_samples, [2.5, 97.5])\n",
"\n",
"# The target was standardized, so this RMSE is in standard-deviation units;\n",
"# multiplying by y_std converts it back to the scale of the original charges.\n",
"print(f'Test RMSE (standardized units): {rmse_mean:.3f} [2.5%-97.5%: {rmse_ci[0]:.3f}, {rmse_ci[1]:.3f}]')\n",
"print(f'Test RMSE (charge units): {rmse_mean * y_std:,.0f}')\n",
"\n",
"idata = az.from_pystan(posterior=fit)\n",
"print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))\n",
"\n",
"az.plot_forest(idata, var_names=['beta'], combined=True, colors=\"purple\")\n",
"plt.axvline(0, linestyle='--', color='gray')\n",
"plt.title(\"Posterior Distributions of Beta Coefficients\")\n",
"# Render the figure explicitly so the cell does not echo a Text(...) repr.\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}