# TODO: organize this file better; everything is currently collected in one
# script.

# Install any missing dependencies from CRAN, then attach all of them quietly.
# Availability is checked with requireNamespace() rather than
# installed.packages(), which the R documentation warns can be slow because it
# scans every installed package.
suppressPackageStartupMessages({
  pkgs <- c("tidyverse", "readr", "readxl", "broom", "jsonlite", "ggplot2",
            "class", "optparse")
  to_install <- pkgs[!vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)]
  if (length(to_install) > 0) {
    install.packages(to_install, repos = "https://cloud.r-project.org")
  }
  # invisible(): otherwise the list returned by lapply() is auto-printed when
  # the script is run with Rscript.
  invisible(lapply(pkgs, library, character.only = TRUE))
})

# Make text safe to embed in a file name. Each run of characters other than
# ASCII letters, digits, "_", "." and "-" is collapsed into a single "_".
# Vectorised over `x`.
sanitize <- function(x) {
  disallowed <- "[^A-Za-z0-9_.-]+"
  gsub(pattern = disallowed, replacement = "_", x = x)
}

# Write ggplot `p` to `path`, creating the parent directory first if needed.
# `w`/`h` are forwarded to ggsave() as width/height (ggsave's default units),
# and `dpi` sets the raster resolution. Returns whatever ggsave() returns.
save_plot <- function(p, path, w = 7, h = 5, dpi = 160) {
  out_dir <- dirname(path)
  dir.create(out_dir, showWarnings = FALSE, recursive = TRUE)
  ggplot2::ggsave(filename = path, plot = p, width = w, height = h, dpi = dpi)
}

# Build a histogram of `v` on the density scale (30 bins, alpha 0.6) with a
# geom_density() curve overlaid. `lbl` labels the x axis and the title.
# Returns the ggplot object without drawing or saving it.
hist_density_plot <- function(v, lbl) {
  plot_data <- tibble(x = v)
  plot_title <- paste("histogram + density:", lbl)

  ggplot(data = plot_data, mapping = aes(x = x)) +
    geom_histogram(aes(y = after_stat(density)), bins = 30, alpha = 0.6) +
    geom_density() +
    labs(title = plot_title, x = lbl, y = "density") +
    theme_minimal()
}

# Build a single vertical boxplot of `v`. `lbl` labels the y axis and the
# title; the x axis is left unlabelled. Returns the ggplot object.
box_plot <- function(v, lbl) {
  plot_data <- tibble(x = v)
  plot_title <- paste("boxplot:", lbl)

  ggplot(data = plot_data, mapping = aes(y = x)) +
    geom_boxplot(width = 0.3) +
    labs(title = plot_title, y = lbl, x = NULL) +
    theme_minimal()
}

# Two-sample quantile-quantile plot of `a` (x axis) against `b` (y axis).
#
# Non-finite values (NA, NaN, Inf) are removed from each sample. Both samples
# are then evaluated at `n_q` evenly spaced probabilities in [0.01, 0.99], and
# the paired quantiles are plotted against the y = x reference line.
#
# Args:
#   a, b:         numeric samples to compare.
#   title:        plot title.
#   n_q:          number of quantile points. Defaults to the size of the
#                 smaller cleaned sample, but never fewer than 10.
#   x_lab, y_lab: axis labels.
# Returns:
#   A ggplot object. Returns NULL when either sample has no finite values, or
#   when an explicit `n_q` below 10 is requested.
qq_two_sample <- function(a, b, title = "qq plot", n_q = NULL,
                          x_lab = "region a quantiles",
                          y_lab = "region b quantiles") {
  a <- a[is.finite(a)]
  b <- b[is.finite(b)]
  n <- min(length(a), length(b))

  # The default n_q is always >= 10, so the n_q check below cannot catch an
  # empty sample. Without this guard every quantile would be NA.
  if (n == 0) return(NULL)

  if (is.null(n_q)) n_q <- max(10, n)
  if (n_q < 10) return(NULL)

  probs <- seq(0.01, 0.99, length.out = n_q)
  # Non-finite values were already removed, so na.rm is unnecessary here.
  qa <- quantile(a, probs, names = FALSE)
  qb <- quantile(b, probs, names = FALSE)
  # Quantiles at increasing probabilities are already ordered; sort() only
  # guards against floating-point jitter.
  d <- tibble(x = sort(qa), y = sort(qb))

  ggplot(d, aes(x, y)) +
    geom_point(size = 1.6) +
    geom_abline(slope = 1, intercept = 0) +
    labs(title = title, x = x_lab, y = y_lab) +
    theme_minimal()
}

# Coerce `s` to numeric and apply log1p() only when every finite value is
# strictly positive; otherwise return the coerced values unchanged.
# Non-finite entries (NA, NaN, Inf) are ignored when making that decision.
tf_pos <- function(s) {
  vals <- as.numeric(s)
  finite_vals <- vals[is.finite(vals)]
  if (any(finite_vals <= 0)) {
    return(vals)
  }
  log1p(vals)
}

# Stratified random train/test split.
#
# Rows with a missing label are dropped first. Then, for each distinct label
# (in order of first appearance), max(1, floor(n_class * test_prop)) randomly
# chosen rows go to the test set and the rest go to the training set.
# A class with a single row therefore lands entirely in the test set.
# Uses the global RNG; call set.seed() beforehand for reproducibility.
#
# Args:
#   d:         data frame or tibble.
#   label_col: name (or index) of the label column.
#   test_prop: fraction of each class to hold out for testing.
# Returns:
#   list(train = , test = ), both row subsets of `d` after NA-label removal.
strat_split <- function(d, label_col, test_prop = 0.25) {
  labels <- d[[label_col]]
  if (is.null(labels)) {
    stop("strat_split: label column not found: ", label_col, call. = FALSE)
  }

  # Base indexing instead of drop_na({{label_col}}): label_col is a string, and
  # tidy evaluation could resolve it against a same-named data column.
  keep <- !is.na(labels)
  d <- d[keep, , drop = FALSE]
  labels <- labels[keep]

  classes <- unique(labels)
  idx_tr <- vector("list", length(classes))
  idx_te <- vector("list", length(classes))

  for (i in seq_along(classes)) {
    rows <- which(labels == classes[i])
    nte <- max(1, floor(length(rows) * test_prop))
    # Sample positions, not values. sample(rows, nte) treats a length-1 `rows`
    # as 1:rows and would draw an arbitrary row from another class.
    te <- rows[sample.int(length(rows), nte)]
    idx_te[[i]] <- te
    idx_tr[[i]] <- setdiff(rows, te)
  }

  idx_tr <- as.integer(unlist(idx_tr, use.names = FALSE))
  idx_te <- as.integer(unlist(idx_te, use.names = FALSE))
  list(train = d[idx_tr, , drop = FALSE], test = d[idx_te, , drop = FALSE])
}

# Fit and evaluate a k-nearest-neighbour classifier on a stratified split.
#
# Steps:
#   1. Keep only `label_col` and `vars`, and drop rows with any NA.
#   2. Split per class with strat_split().
#   3. Standardise predictors using the training means and SDs only.
#   4. Predict test labels with class::knn().
#   5. Save a confusion-matrix heat map to `fig_dir`.
#
# Args:
#   d:         data frame with the label and predictor columns.
#   label_col: name of the label column (string).
#   vars:      character vector of predictor column names (expected numeric).
#   k:         number of neighbours passed to class::knn().
#   tag:       run identifier used in the plot title and (sanitised) file name.
#   fig_dir:   output directory for the PNG (created if missing).
#   test_prop: per-class fraction held out for testing (default 0.25).
# Returns:
#   list(tag, k, vars, accuracy, confusion_fig, n_test). `accuracy` is the
#   fraction of all n_test test rows that were predicted correctly.
run_knn <- function(d, label_col, vars, k, tag, fig_dir, test_prop = 0.25) {
  use <- d |>
    dplyr::select(dplyr::all_of(c(label_col, vars))) |>
    tidyr::drop_na()

  parts <- strat_split(use, label_col, test_prop)
  tr <- parts$train
  te <- parts$test

  # Scale with training statistics only, so test rows cannot leak into it.
  mu <- vapply(tr[, vars, drop = FALSE], mean, numeric(1), na.rm = TRUE)
  sdv <- vapply(tr[, vars, drop = FALSE], sd, numeric(1), na.rm = TRUE)
  # A constant column has sd 0 and a single training row has sd NA. Either
  # would make the scaled column NaN/NA, which class::knn() rejects.
  sdv[!is.finite(sdv) | sdv == 0] <- 1

  trX <- scale(as.matrix(tr[, vars, drop = FALSE]), center = mu, scale = sdv)
  teX <- scale(as.matrix(te[, vars, drop = FALSE]), center = mu, scale = sdv)

  trY <- factor(tr[[label_col]])
  pred <- class::knn(train = trX, test = teX, cl = trY, k = k)

  # Compare with the actual test labels. A factor restricted to the training
  # levels would turn unseen classes into NA; those rows would then be dropped
  # from the accuracy, making its denominator disagree with n_test.
  truth <- as.character(te[[label_col]])
  acc <- mean(as.character(pred) == truth)

  truth_f <- factor(truth, levels = union(levels(trY), unique(truth)))
  cm_df <- as.data.frame(table(truth = truth_f, pred = pred))
  cm_fig <- file.path(fig_dir, paste0("knn_confusion_", sanitize(tag), ".png"))

  p_cm <- ggplot(cm_df, aes(x = pred, y = truth, fill = Freq)) +
    geom_tile() +
    geom_text(aes(label = Freq)) +
    labs(
      title = paste0("confusion matrix (k=", k, ") – ", tag),
      x = "predicted", y = "true"
    ) +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)
    )

  save_plot(p_cm, cm_fig, w = 6, h = 5)

  list(tag = tag,
       k = k,
       vars = vars,
       accuracy = acc,
       confusion_fig = cm_fig,
       n_test = nrow(te))
}