123 lines
3.5 KiB
R
123 lines
3.5 KiB
R
|
|
# TODO: reorganize this file; everything was dumped here without any structure.
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Install any missing dependencies from CRAN, then attach all of them with
# package start-up messages suppressed. Note: `pkgs` and `to_install` are
# created in the calling (top-level) environment.
suppressPackageStartupMessages({
  pkgs <- c("tidyverse", "readr", "readxl", "broom", "jsonlite", "ggplot2",
            "class", "optparse")
  # system.file() is the cheap per-package presence check; installed.packages()
  # scans every library and is documented as slow for this purpose.
  to_install <- Filter(function(p) !nzchar(system.file(package = p)), pkgs)
  if (length(to_install) > 0) {
    install.packages(to_install, repos = "https://cloud.r-project.org")
  }
  # invisible(): otherwise the list returned by lapply() is the value of this
  # top-level expression and gets auto-printed when run via Rscript.
  invisible(lapply(pkgs, library, character.only = TRUE))
})
|
|||
|
|
|
|||
|
|
# Make `x` safe to embed in a file name: every run of one or more characters
# outside [A-Za-z0-9_.-] is collapsed into a single underscore.
# Vectorised over `x`; returns a character vector of the same length.
sanitize <- function(x) {
  unsafe_run <- "[^A-Za-z0-9_.-]+"
  gsub(pattern = unsafe_run, replacement = "_", x = x)
}
|
|||
|
|
|
|||
|
|
# Write ggplot `p` to `path`, creating the parent directory first if needed.
# `w` and `h` are passed to ggplot2::ggsave() as width/height (ggsave's default
# units); `dpi` sets the raster resolution. Returns whatever ggsave() returns.
save_plot <- function(p, path, w = 7, h = 5, dpi = 160) {
  out_dir <- dirname(path)
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
  ggplot2::ggsave(filename = path, plot = p, width = w, height = h, dpi = dpi)
}
|
|||
|
|
|
|||
|
|
# Histogram of `v` (30 bins, on the density scale) overlaid with a kernel
# density curve. `lbl` is used for the title and the x-axis label.
# Returns a ggplot object (nothing is drawn or saved here).
hist_density_plot <- function(v, lbl) {
  plot_data <- tibble(x = v)
  plot_title <- paste("histogram + density:", lbl)

  ggplot(plot_data, aes(x = x)) +
    geom_histogram(aes(y = after_stat(density)), bins = 30, alpha = 0.6) +
    geom_density() +
    labs(title = plot_title, x = lbl, y = "density") +
    theme_minimal()
}
|
|||
|
|
|
|||
|
|
# Single vertical boxplot of `v`; `lbl` is used for the title and y-axis label.
# Returns a ggplot object (nothing is drawn or saved here).
box_plot <- function(v, lbl) {
  plot_data <- tibble(x = v)
  plot_title <- paste("boxplot:", lbl)

  ggplot(plot_data, aes(y = x)) +
    geom_boxplot(width = 0.3) +
    labs(title = plot_title, y = lbl, x = NULL) +
    theme_minimal()
}
|
|||
|
|
|
|||
|
|
# Two-sample quantile-quantile plot of `a` (x axis) against `b` (y axis).
#
# Non-finite values (NA, NaN, +/-Inf) are dropped from both samples first.
# Quantiles are taken at `n_q` evenly spaced probabilities in [0.01, 0.99]
# and plotted against each other with the y = x reference line.
#
# a, b          numeric vectors to compare.
# title         plot title.
# n_q           number of quantile points; defaults to the size of the
#               smaller sample, with a floor of 10.
# x_lab, y_lab  axis labels (defaults keep the original wording).
#
# Returns a ggplot object, or NULL when either sample has no finite values
# or when an explicit `n_q` below 10 is requested.
qq_two_sample <- function(a, b, title = "qq plot", n_q = NULL,
                          x_lab = "region a quantiles",
                          y_lab = "region b quantiles") {
  a <- a[is.finite(a)]
  b <- b[is.finite(b)]
  n <- min(length(a), length(b))

  # The default n_q below is always >= 10, so the n_q guard alone never
  # catches an empty sample; without this check we'd plot all-NA quantiles.
  if (n == 0L) return(NULL)

  if (is.null(n_q)) n_q <- max(10, n)
  if (n_q < 10) return(NULL)

  probs <- seq(0.01, 0.99, length.out = n_q)
  # Inputs are already finite, so no na.rm is needed.
  qa <- quantile(a, probs, names = FALSE)
  qb <- quantile(b, probs, names = FALSE)
  d <- tibble(x = sort(qa), y = sort(qb))

  ggplot(d, aes(x, y)) +
    geom_point(size = 1.6) +
    geom_abline(slope = 1, intercept = 0) +
    labs(title = title, x = x_lab, y = y_lab) +
    theme_minimal()
}
|
|||
|
|
|
|||
|
|
# Coerce `s` to numeric and apply log1p() when every finite value is strictly
# positive; otherwise return the numeric vector unchanged. Non-finite values
# (NA, NaN, +/-Inf) are ignored by the positivity check but still pass through
# whichever branch is taken.
tf_pos <- function(s) {
  vals <- as.numeric(s)
  finite_vals <- vals[is.finite(vals)]
  if (any(finite_vals <= 0)) {
    return(vals)
  }
  log1p(vals)
}
|
|||
|
|
|
|||
|
|
# Stratified train/test split by class label.
#
# Rows whose label is NA are dropped first. Within each class,
# floor(n * test_prop) rows (at least one) are drawn at random for the test
# set, but at least one row of every class is always kept for training. A
# single-row class therefore goes entirely to training, so every class that
# appears in test was also seen in training.
#
# d          data frame / tibble.
# label_col  name (string) of the label column.
# test_prop  fraction of each class to hold out.
#
# Returns list(train = <rows>, test = <rows>), each a row subset of `d` with
# all columns, grouped by class in order of first appearance. Uses the RNG;
# call set.seed() beforehand for reproducible splits.
strat_split <- function(d, label_col, test_prop = 0.25) {
  # Base subsetting: label_col is a string, so no tidy-eval is needed.
  d <- d[!is.na(d[[label_col]]), , drop = FALSE]
  labels <- d[[label_col]]
  classes <- unique(labels)

  idx_tr <- vector("list", length(classes))
  idx_te <- vector("list", length(classes))

  for (i in seq_along(classes)) {
    rows <- which(labels == classes[i])
    n_rows <- length(rows)
    n_te <- min(max(1, floor(n_rows * test_prop)), n_rows - 1)
    # Index via sample.int(): sample(rows, k) treats a length-1 `rows` as
    # 1:rows and would return an arbitrary row from another class.
    te <- rows[sample.int(n_rows, n_te)]
    idx_te[[i]] <- te
    idx_tr[[i]] <- setdiff(rows, te)
  }

  idx_tr <- as.integer(unlist(idx_tr, use.names = FALSE))
  idx_te <- as.integer(unlist(idx_te, use.names = FALSE))

  list(train = d[idx_tr, , drop = FALSE], test = d[idx_te, , drop = FALSE])
}
|
|||
|
|
|
|||
|
|
# Fit and evaluate a k-nearest-neighbour classifier on one feature set.
#
# Rows with a missing value in `label_col` or any of `vars` are dropped. The
# rest are split per class by strat_split() (25% test). Features are z-scored
# with training-set means/SDs, and class::knn() predicts the test labels. A
# confusion-matrix heatmap is saved under `fig_dir`.
#
# d          data frame holding the label and feature columns.
# label_col  name (string) of the class-label column.
# vars       character vector of feature column names (presumably numeric).
# k          number of neighbours passed to class::knn().
# tag        free-text run label; used in the plot title and, sanitised, in
#            the figure file name.
# fig_dir    directory for the confusion-matrix PNG (created if needed).
#
# Returns a list with tag, k, vars, accuracy (share of test rows predicted
# correctly), confusion_fig (path of the saved PNG) and n_test.
run_knn <- function(d, label_col, vars, k, tag, fig_dir) {
  use <- d |>
    dplyr::select(all_of(c(label_col, vars))) |>
    tidyr::drop_na()
  split <- strat_split(use, label_col, 0.25)
  tr <- split$train
  te <- split$test

  # Standardise with training statistics only.
  mu <- vapply(tr[, vars, drop = FALSE], mean, numeric(1), na.rm = TRUE)
  sdv <- vapply(tr[, vars, drop = FALSE], sd, numeric(1), na.rm = TRUE)
  # Constant columns (sd 0) and a one-row training set (sd NA) would
  # otherwise turn the scaled features into NaN/NA.
  sdv[!is.finite(sdv) | sdv == 0] <- 1

  trX <- scale(as.matrix(tr[, vars, drop = FALSE]), center = mu, scale = sdv)
  teX <- scale(as.matrix(te[, vars, drop = FALSE]), center = mu, scale = sdv)

  trY <- factor(tr[[label_col]])
  truth_chr <- as.character(te[[label_col]])
  # Test classes never seen in training cannot be predicted. Give them their
  # own levels so they count as errors and appear in the confusion matrix,
  # instead of becoming NA and being silently dropped from the accuracy.
  all_levels <- union(levels(trY), unique(truth_chr))
  teY <- factor(truth_chr, levels = all_levels)

  pred <- class::knn(train = trX, test = teX, cl = trY, k = k)
  pred <- factor(as.character(pred), levels = all_levels)
  acc <- mean(pred == teY)

  cm <- table(truth = teY, pred = pred)
  cm_df <- as.data.frame(cm)
  cm_fig <- file.path(fig_dir, paste0("knn_confusion_", sanitize(tag), ".png"))

  p_cm <- ggplot(cm_df, aes(x = pred, y = truth, fill = Freq)) +
    geom_tile() +
    geom_text(aes(label = Freq)) +
    labs(
      title = paste0("confusion matrix (k=", k, ") – ", tag),
      x = "predicted", y = "true"
    ) +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)
    )

  save_plot(p_cm, cm_fig, w = 6, h = 5)

  list(tag = tag,
       k = k,
       vars = vars,
       accuracy = unname(acc),
       confusion_fig = cm_fig,
       n_test = nrow(te))
}
|