# KNN Classification

In [None]:
library("caret")
# Open the help page for caret's knn3() k-nearest-neighbour classifier.
?knn3


In [None]:
library("palmerpenguins")
# Drop every row that has a missing value in any column, so later fits and
# summaries work on complete observations only. The global `penguins` now
# shadows the packaged dataset of the same name.
penguins <- penguins[complete.cases(penguins), ]
head(penguins)


In [None]:
# Per-column summary statistics of the cleaned data.
summary(penguins)


In [None]:
# Rows and columns remaining after dropping incomplete rows.
dim(penguins)


In [None]:
# Share of each species among the rows (class balance of the response).
# After complete.cases() there are no NA species, so the table total equals nrow().
prop.table(table(penguins$species))


In [None]:
# The response should be a "factor": knn3() and confusionMatrix() treat it as
# the set of classes to predict.
class(penguins$species)


# Fit the model

In [None]:
# Fit a 5-nearest-neighbour classifier of species on every other column.
# NOTE(review): the predictors are not standardized here, so columns with large
# numeric ranges (presumably body_mass_g, in grams) will dominate the distance.
# The standardized plot_fit() further down addresses this; confirm intent.
mod <- knn3(species ~ ., data = penguins, k = 5)


In [None]:
# Predicted classes for the very rows the model was trained on (resubstitution).
# Agreement with the truth here is an optimistic estimate of real accuracy.
preds <- predict(mod, newdata = penguins, type = "class")
preds[seq_len(10)]


In [None]:
# Neighbour vote proportions per species (one column per class) for the first
# few training rows.
head(predict(mod, newdata = penguins, type = "prob"))


In [None]:
# Cross-tabulate predictions (data) against the true species (reference).
# caret prints the table together with overall and per-class statistics.
cm <- confusionMatrix(data = preds, reference = penguins$species)
cm


In [None]:
# Raw counts: rows are predicted classes, columns are reference (true) classes,
# matching the data/reference order passed to confusionMatrix().
C <- as.matrix(cm$table)
C


In [None]:
# Overall accuracy: correctly classified counts (the diagonal) over all rows.
sum(diag(C)) / sum(C)


In [None]:
# Recall (sensitivity) of the first class (presumably Adelie, the first factor
# level): correct predictions of it over all rows whose true class it is (column 1).
C[1, 1] / sum(C[, 1])


In [None]:
# Recall (sensitivity) of the second class (presumably Chinstrap): diagonal count
# over the column total for that true class.
C[2, 2] / sum(C[, 2])


# Plot the decision regions (predicted class over a grid of two predictors)

In [None]:
library("ggplot2")
library("dplyr")
library("tidyr")
library("purrr")


In [None]:
# Example pair of predictor columns.
# NOTE(review): these globals are never read. plot_fit()'s parameters of the
# same name shadow them, and every call below passes column names explicitly.
v1 <- "bill_length_mm"
v2 <- "bill_depth_mm"


In [None]:
# Plot KNN decision regions for two raw (unscaled) predictors.
#
# v1, v2: names of the numeric columns of `df` used for the x and y axes.
# df:     data holding a factor column `species` plus v1 and v2.
# N:      grid points per axis; predictions are made on an N x N grid.
# k:      number of neighbours passed to knn3().
#
# Tiles show the predicted class over a grid spanning the observed range of each
# predictor; points are the training observations. Returns a ggplot object.
plot_fit <- function(v1, v2, df = penguins, N = floor(sqrt(10000)), k = 10) {
    feats <- c(v1, v2)
    train_df <- select(df, all_of(c("species", feats)))

    fit <- knn3(species ~ ., data = train_df, k = k)

    # One evenly spaced axis per predictor; the list is named v1, v2, so the
    # expanded grid already carries the right column names.
    axis_vals <- lapply(train_df[feats], function(col) seq(min(col), max(col), length.out = N))
    grid <- expand_grid(!!!axis_vals)
    grid$species <- predict(fit, newdata = grid, type = "class")

    ggplot(grid, aes(x = !!sym(v1), y = !!sym(v2), fill = species, shape = species)) +
        geom_tile() +
        geom_point(data = train_df, size = 5)
}


In [None]:
# Raw-scale bill length vs bill depth with k = 50: each prediction is a vote
# over 50 neighbours, giving smoother regions.
plot_fit(v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins, k = 50)


In [None]:
# Same predictors with k = 10: more local, more irregular boundaries.
plot_fit(v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins, k = 10)


In [None]:
# k = 1: every grid point takes the class of its single nearest training point.
plot_fit(v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins, k = 1)


In [None]:
# Same k = 1 fit drawn with equal axis units, which exposes how the unscaled
# millimetre ranges distort the neighbourhoods.
plot_fit(v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins, k = 1) + coord_fixed()


In [None]:
# Plot KNN decision regions for two predictors after standardizing them.
#
# v1, v2: names of the numeric columns of `df` used for the x and y axes.
# df:     data holding a factor column `species` plus v1 and v2.
# N:      grid points per axis; predictions are made on an N x N grid.
# k:      number of neighbours passed to knn3().
#
# Both predictors are z-scored ((x - mean) / sd) before fitting, so each axis
# contributes on a comparable scale to the neighbour distance. Tiles show the
# predicted class over a grid spanning the standardized data; points are the
# standardized training observations. Axis labels say "(standardized)" because
# the plotted values are z-scores, not the units in the column names.
# Returns a ggplot object.
plot_fit <- function(v1, v2, df = penguins, N = floor(sqrt(10000)), k = 10) {
    train_df <- df %>% select(all_of(c("species", v1, v2)))
    train_df <- train_df %>% mutate(across(-species, ~ (.x - mean(.x)) / sd(.x)))

    mod <- knn3(species ~ ., data = train_df, k = k)

    # Evenly spaced grid over the observed (standardized) range of each predictor.
    combinations <- expand_grid(!!!map(train_df %>% select(-species), ~ seq(min(.x), max(.x), length.out = N)))
    colnames(combinations) <- c(v1, v2)
    combinations$species <- predict(mod, newdata = combinations, type = "class")

    ggplot(data = combinations, mapping = aes(x = !!sym(v1), y = !!sym(v2), fill = species, shape = species)) +
        geom_tile() +
        geom_point(data = train_df, size = 5) +
        # Equal units on both axes so on-screen distance matches model distance.
        coord_fixed() +
        labs(x = paste0(v1, " (standardized)"), y = paste0(v2, " (standardized)"))
}


In [None]:
# plot_fit() now standardizes its two predictors: bill length vs depth, k = 1.
plot_fit(v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins, k = 1)


In [None]:
# Flipper length vs bill depth (the pair tuned below), k = 1.
plot_fit(v1 = "flipper_length_mm", v2 = "bill_depth_mm", df = penguins, k = 1)


In [None]:
# Same pair with k = 50 for comparison: smoother regions.
plot_fit(v1 = "flipper_length_mm", v2 = "bill_depth_mm", df = penguins, k = 50)


# Tuning k with 10-fold cross-validation

In [None]:
library('recipes')
library('rsample')
library('yardstick')

In [None]:
# Preprocessing specification: model species from flipper length and bill
# depth, then centre and scale every numeric predictor. The means and standard
# deviations are estimated only when the recipe is prep()-ed on a training set.
rec_obj <- recipe(species ~ flipper_length_mm + bill_depth_mm, data = penguins) %>%
    step_center(all_numeric_predictors()) %>%
    step_scale(all_numeric_predictors())
rec_obj

In [None]:
# Ten-fold cross-validation splits of the cleaned data. Fold assignment is
# random, so fix the seed to make the folds, and every accuracy derived from
# them, reproducible across runs.
set.seed(1234)
splts <- rsample::vfold_cv(penguins, v = 10)

In [None]:
# Label of the first split (presumably "Fold01"); fit_fn_K() reads the same
# field to tag its results.
splts$splits[[1]]$id$id

In [None]:
# Fit and evaluate one KNN model on a single resample.
#
# splt: an rsplit, i.e. one element of rsample::vfold_cv()$splits.
# K:    number of neighbours passed to knn3() (default 5, knn3's own default;
#       the previous NULL default made knn3() fail).
# rec:  an un-prepped recipe. It is prepped on the training part of `splt`
#       only, so preprocessing statistics never see the held-out rows.
#       Defaults to the global `rec_obj`, as before.
#
# Returns one tibble with the baked training and test rows, each carrying the
# predicted class `est`, a `type` label ("train" / "test"), the `K` used and the
# fold `id`.
fit_fn_K <- function(splt, K = 5, rec = rec_obj) {
    trained_rec <- prep(rec, training = training(splt))
    train_data <- bake(trained_rec, new_data = training(splt))
    test_data <- bake(trained_rec, new_data = testing(splt))

    mod <- knn3(species ~ ., data = train_data, k = K)
    train_data$est <- predict(mod, train_data, type = "class")
    train_data$type <- "train"
    test_data$est <- predict(mod, test_data, type = "class")
    test_data$type <- "test"

    df <- bind_rows(train_data, test_data)
    df$K <- K
    df$id <- splt$id$id
    df
}

In [None]:
# Candidate neighbourhood sizes: the integers 1, 2, ..., 100.
K_seq <- seq_len(100)


In [None]:
# Every (split, K) pair: one row per fold per candidate K. crossing() also
# de-duplicates and sorts its inputs.
cmbs <- crossing(splt = splts$splits,K = K_seq)

In [None]:

# Fit and score one model per (split, K) row. pmap() passes the columns by name,
# so `splt` and `K` land on fit_fn_K()'s arguments of the same name.
res <- pmap(cmbs, fit_fn_K)

In [None]:
# Stack the per-(fold, K) tibbles into one long table and peek at random rows.
# slice_sample() replaces the superseded sample_n().
res_all <- bind_rows(res)
res_all %>% slice_sample(n = 10)

In [None]:
# Accuracy for every (train/test type, K, fold) combination.
acc_tbl <- res_all %>%
    group_by(type, K, id) %>%
    accuracy(truth = species, estimate = est) %>%
    ungroup()
acc_tbl %>% slice_sample(n = 10)

In [None]:
# Mean and interquartile range of accuracy across folds, per type and K.
# `.groups = "drop"` returns an ungrouped tibble and avoids the regrouping message.
acc_smry <- acc_tbl %>%
    group_by(type, K) %>%
    summarize(
        mean = mean(.estimate),
        q25 = quantile(.estimate, 0.25),
        q75 = quantile(.estimate, 0.75),
        .groups = "drop"
    )
acc_smry %>% slice_sample(n = 10)

In [None]:
library("ggplot2")
# Mean CV accuracy per K (points) with the across-fold interquartile range
# (ribbon), coloured by train vs test rows; K on a log scale.
ggplot(acc_smry, aes(x = K, y = mean, colour = type, fill = type)) +
    geom_ribbon(aes(ymin = q25, ymax = q75), alpha = 0.25) +
    geom_point() +
    scale_x_log10()


In [None]:
# K with the highest mean held-out (test) accuracy across folds.
# Several K can tie for the best mean; knn3() needs a single k, so pick the
# largest tied K. Larger neighbourhoods give smoother, lower-variance boundaries.
K_hat <- acc_smry %>%
    filter(type == "test") %>%
    filter(mean == max(mean)) %>%
    pull(K) %>%
    max()
K_hat


In [None]:
# Decision regions for the tuned K, drawn on the same two predictors
# (flipper_length_mm, bill_depth_mm) that rec_obj used when K was chosen.
plot_fit(v1 = "flipper_length_mm", v2 = "bill_depth_mm", df = penguins, k = K_hat)
