In [None]:
library("ggplot2")
library("reshape2")
library("tibble")
library("dplyr")
library("purrr")
library("tidyr")

In [None]:
# Settings for the simulated one-dimensional, multi-group data set.
# sig2: presumably the common within-group variance
#       (NOTE(review): confirm the sampling code actually reads it).
# mus:  one true mean per group.
# N:    number of observations drawn for each group.
sig2 <- 1
mus <- c(1, 5, 20)
N <- 100

In [None]:
# Per-group sample sizes: one entry per mean in `mus`, so the number of groups
# is defined in exactly one place instead of being repeated as a literal.
Ns <- rep(N, length(mus))
Ns

In [None]:
# Simulate the data: group k contributes Ns[k] draws from a normal distribution
# with mean mus[k] and variance sig2. The standard deviation is derived from
# sig2 so that changing the setting actually changes the simulated data.
sd_true <- sqrt(sig2)
g1 <- data.frame(group = 1, x = rnorm(Ns[1], mean = mus[1], sd = sd_true))
g2 <- data.frame(group = 2, x = rnorm(Ns[2], mean = mus[2], sd = sd_true))
g3 <- data.frame(group = 3, x = rnorm(Ns[3], mean = mus[3], sd = sd_true))
d <- rbind(g1, g2, g3)
d$group <- as.factor(d$group)
# Shuffle the rows so the observations are not ordered by group.
d <- as_tibble(d[sample(nrow(d)), ])

In [None]:
# Peek at the first rows of the shuffled data.
head(d)

In [None]:
# Size the notebook figure, then draw one density curve of x per group with the
# raw observations marked along y = 0.
options(repr.plot.width = 10, repr.plot.height = 3, repr.plot.res = 100)
ggplot(d, aes(x = x, colour = group, group = group)) +
  geom_density() +
  geom_point(aes(y = 0))

In [None]:
# Pick one class and keep only its rows.
class <- 1

subd <- filter(d, group == class)
head(subd)

In [None]:
# mu1_hat: sample mean of x within the selected class only.
# `subd` holds just the rows of `d` whose group equals `class`; averaging over
# all of `d` would return the grand mean of the mixture, not the class mean.
mu_hat <- subd %>%
  summarize(mu_hat = mean(x)) %>%
  pull(mu_hat)
mu_hat

In [None]:
# pi_hat: share of all rows in `d` whose group equals `class`.
pi_hat <- mean(d$group == class)
pi_hat

In [None]:
# Per-group sample variances of x (the ingredients of the pooled variance).
vars <- pull(
  summarize(group_by(d, group), var_hat = var(x)),
  var_hat
)
vars

In [None]:
# Pooled variance: group variances weighted by their degrees of freedom.
# The denominator removes one degree of freedom per group; the group count is
# taken from `vars` rather than written as a literal.
pooled_var <- sum((Ns - 1) * vars) / (sum(Ns) - length(vars))
pooled_var

In [None]:
# Variance of x over all rows, ignoring the groups (compare with pooled_var).
var(d$x)

In [None]:
# Point at which the score below is evaluated.
x0 <- 1

In [None]:
# Normal density at x0 (mean mu_hat, sd sqrt(pooled_var)) times pi_hat.
dnorm(x0, mean = mu_hat, sd = sqrt(pooled_var)) * pi_hat

In [None]:
# Class means, one per group, in the order produced by the grouped summary.
mu_hats <- d %>%
  group_by(group) %>%
  summarize(mu_hat = mean(x)) %>%
  pull(mu_hat)
mu_hats

# Class priors as relative frequencies. count() yields one row per group in the
# same order as the grouped summaries, so pi_hats[k] and mu_hats[k] describe the
# same class. Iterating over unique(d$group) would instead follow the order in
# which the groups first appear in the shuffled rows.
pi_hats <- d %>%
  count(group) %>%
  mutate(pi_hat = n / sum(n)) %>%
  pull(pi_hat)
pi_hats

# Pooled variance: per-group variances weighted by their degrees of freedom,
# with one degree of freedom removed per group actually present.
vars <- d %>%
  group_by(group) %>%
  summarize(var_hat = var(x)) %>%
  pull(var_hat)
pooled_var <- sum((Ns - 1) * vars) / (sum(Ns) - length(vars))
pooled_var

In [None]:
# Distinct group labels in order of first appearance, then the class mean
# looked up through the third of those labels.
groups <- unique(d$group)
mu_hats[groups[3]]

In [None]:
# Unnormalised LDA score of one class at x0: the normal density of x0 under the
# class mean and the pooled standard deviation, multiplied by the class prior.
# Reads mu_hats, pi_hats and pooled_var from the enclosing environment.
#
# Args:
#   x0:    point(s) at which to evaluate the score.
#   class: class index; coerced with as.integer().
# Returns the density-times-prior value.
delta_lda_c <- function(x0, class) {
  k <- as.integer(class)
  likelihood <- dnorm(x0, mean = mu_hats[k], sd = sqrt(pooled_var))
  likelihood * pi_hats[k]
}

In [None]:
# Evaluation point for the three score calls below.
x0 <- 1

In [None]:
# Score of class 1 at x0.
delta_lda_c(x0, class = 1)

In [None]:
# Score of class 2 at x0.
delta_lda_c(x0, class = 2)

In [None]:
# Score of class 3 at x0.
delta_lda_c(x0, class = 3)

In [None]:
# Predict the class of a single point as the index of its largest LDA score.
#
# Args:
#   x0: a single numeric value.
# Returns the integer index of the class with the highest delta_lda_c() score.
lda_pred <- function(x0) {
  # One score per estimated class. The class count follows mu_hats instead of a
  # literal 1:3, and vapply() guarantees a plain numeric vector with exactly
  # one score per class before it reaches which.max().
  deltas <- vapply(
    seq_along(mu_hats),
    function(k) delta_lda_c(x0, k),
    numeric(1)
  )
  which.max(deltas)
}

In [None]:
# Predicted class at x0 = 1.
lda_pred(x0 = 1)

In [None]:
# Predicted class at x0 = 7.
lda_pred(x0 = 7)

In [None]:
# Predicted class at x0 = 20.
lda_pred(x0 = 20)

In [None]:
# Evenly spaced grid of 500 points on [-2, 25] used for the curves below.
x_seq <- seq(-2, 25, length.out = 500)

In [None]:
# For every grid point: the predicted class (as a factor) and the score of
# each of the three classes.
df <- tibble(x = x_seq) %>%
  mutate(
    y_pred = factor(map_vec(x, lda_pred)),
    c1 = map_vec(x, function(xi) delta_lda_c(xi, class = 1)),
    c2 = map_vec(x, function(xi) delta_lda_c(xi, class = 2)),
    c3 = map_vec(x, function(xi) delta_lda_c(xi, class = 3))
  )
head(df)

In [None]:
# Per-group density of x with the observations marked along y = 0.
ggplot(d, aes(x = x, colour = group, group = group)) +
  geom_density() +
  geom_point(aes(y = 0))

In [None]:
# Long format: one row per (grid point, class score); the predicted class is
# relabelled "c1"/"c2"/"c3" so it matches the score column names.
df_long <- df %>%
  pivot_longer(cols = c(c1, c2, c3)) %>%
  mutate(y_pred = paste0("c", y_pred))
head(df_long)

In [None]:
# Score curves per class (lines) with the predicted class of every grid point
# drawn as points along y = 0.
options(repr.plot.width = 10, repr.plot.height = 4, repr.plot.res = 100)
ggplot(df_long, aes(x = x, y = 0, colour = y_pred, group = y_pred)) +
  geom_line(
    aes(x = x, y = value, group = name, colour = name),
    lwd = 2, inherit.aes = FALSE
  ) +
  geom_point()

In [None]:
# MASS supplies lda() and qda(). It is attached after dplyr, so its select()
# masks dplyr's; later cells therefore spell out dplyr::select().
library("MASS")
# Open the help page for lda().
?lda

In [None]:
# Hand-rolled LDA prediction for every training observation.
# vapply() pins the result to one integer per observation; sapply() would
# silently change its return type if a prediction ever came back empty.
my_preds <- vapply(d$x, lda_pred, integer(1))
my_preds

In [None]:
# Fit MASS::lda() with group as the response and every other column of d
# (here only x) as predictor.
mod <- lda(group ~ ., data = d)

In [None]:
# Print the fitted model.
mod

In [None]:
# Hand-computed mu_hat from earlier, for comparison with the printout above.
mu_hat

In [None]:
# Hand-computed pi_hat from earlier, for comparison with the printout above.
pi_hat

In [None]:
# Classes predicted by the fitted model for the training data.
mod_preds <- predict(mod)$class
mod_preds

In [None]:
# Do the package predictions agree with the hand-rolled ones everywhere?
all(mod_preds == my_preds)

In [None]:
# First rows of the posterior component returned by predict().
head(predict(mod)$posterior)

In [None]:
# Posterior values of the fitted model on the grid, reshaped to long format
# (columns 2 to 4 hold the posterior columns) with the class name as a factor.
mod_df <- tibble(x = x_seq) %>%
  bind_cols(predict(mod, newdata = data.frame(x = x_seq))$posterior) %>%
  pivot_longer(cols = 2:4) %>%
  mutate(name = factor(name))
head(mod_df)

In [None]:
# One posterior curve per class along the grid.
options(repr.plot.width = 10, repr.plot.height = 3, repr.plot.res = 100)
ggplot(mod_df, aes(x = x, y = value, colour = name, group = name)) +
  geom_line(lwd = 2)

In [None]:
# Linear-in-x0 form of the class score:
#   mu_k * x0 / pooled_var - mu_k^2 / (2 * pooled_var) + log(pi_k).
# Reads mu_hats, pi_hats and pooled_var from the enclosing environment.
#
# Args:
#   x0:    point(s) at which to evaluate the score.
#   class: class index; coerced with as.integer().
# Returns the score value.
delta_lda_c2 <- function(x0, class) {
  k <- as.integer(class)
  mu_k <- mu_hats[k]
  mu_k * x0 / pooled_var - mu_k^2 / (2 * pooled_var) + log(pi_hats[k])
}
# Predict the class of a single point as the index of its largest
# delta_lda_c2() score.
#
# Args:
#   x0: a single numeric value.
# Returns the integer index of the winning class.
lda_pred2 <- function(x0) {
  # One score per estimated class. The class count follows mu_hats instead of a
  # literal 1:3, and vapply() guarantees a plain numeric vector with exactly
  # one score per class before it reaches which.max().
  deltas <- vapply(
    seq_along(mu_hats),
    function(k) delta_lda_c2(x0, k),
    numeric(1)
  )
  which.max(deltas)
}

In [None]:
# For every grid point: the predicted class (as a factor) and the
# delta_lda_c2() score of each of the three classes.
df <- tibble(x = x_seq) %>%
  mutate(
    y_pred = factor(map_vec(x, lda_pred2)),
    c1 = map_vec(x, function(xi) delta_lda_c2(xi, class = 1)),
    c2 = map_vec(x, function(xi) delta_lda_c2(xi, class = 2)),
    c3 = map_vec(x, function(xi) delta_lda_c2(xi, class = 3))
  )
head(df)

In [None]:
# Reshape the scores to long format and plot them: one line per class score,
# with the predicted class of every grid point drawn as points along y = 0.
options(repr.plot.width = 10, repr.plot.height = 4, repr.plot.res = 100)

df_long <- df %>%
  pivot_longer(cols = c(c1, c2, c3)) %>%
  mutate(y_pred = paste0("c", y_pred))

ggplot(df_long, aes(x = x, y = 0, colour = y_pred, group = y_pred)) +
  geom_line(
    aes(x = x, y = value, group = name, colour = name),
    lwd = 2, inherit.aes = FALSE
  ) +
  geom_point()

# LDA from a package

In [None]:
library("palmerpenguins")

In [None]:
# Keep only the rows of penguins without any missing value.
penguins <- filter(penguins, complete.cases(penguins))

In [None]:
# Two bill measurements plus the species label.
d <- dplyr::select(penguins, bill_length_mm, bill_depth_mm, species)
head(d)

In [None]:
# Fit lda() with species as the response and the remaining columns of d as
# predictors, then print the fit.
mod <- lda(species ~ ., data = d)
mod

In [None]:
# Plot the class regions predicted by lda() for two predictors.
#
# Fits lda() on `species` and the two chosen columns, predicts the class on an
# N x N grid spanning the observed range of both columns, and draws the
# predicted classes as tiles with the training points on top.
#
# Args:
#   v1, v2:  column names (strings) shown on the x and y axis.
#   df:      data frame containing `species`, v1 and v2.
#   N:       number of grid points per axis.
#   scaleit: if TRUE, centre each predictor and divide it by its sd before
#            fitting.
#   fmla:    model formula given as a string; evaluated on the selected columns.
# Returns a ggplot object.
plot_fit <- function(v1, v2, df = penguins, N = floor(sqrt(10000)), scaleit = TRUE, fmla = "species~.") {
  # Drop rows with a missing value in any used column up front: a single NA
  # would otherwise turn mean()/sd() and min()/max() below into NA.
  train_df <- df %>%
    dplyr::select(all_of(c("species", v1, v2))) %>%
    tidyr::drop_na()
  if (scaleit) {
    train_df <- train_df %>%
      mutate(across(-species, ~ (.x - mean(.x)) / sd(.x)))
  }

  mod <- lda(formula = as.formula(fmla), data = train_df)

  # One evenly spaced axis per predictor; the list keeps the column names, so
  # the expanded grid is already named v1 / v2.
  grid_axes <- map(
    train_df %>% dplyr::select(-species),
    ~ seq(min(.x), max(.x), length.out = N)
  )
  combinations <- expand_grid(!!!grid_axes)
  combinations$species <- predict(mod, newdata = combinations)$class

  ggplot(
    data = combinations,
    mapping = aes(x = !!sym(v1), y = !!sym(v2), fill = species, shape = species)
  ) +
    geom_tile() +
    geom_point(data = train_df, size = 5) +
    coord_fixed()
}

In [None]:
# Use a 10 x 10 figure for the region plots.
options(repr.plot.width = 10, repr.plot.height = 10, repr.plot.res = 100)
# LDA regions for bill length against bill depth.
plot_fit(v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins)

In [None]:
# Same plot with flipper length on the x axis.
plot_fit(v1 = "flipper_length_mm", v2 = "bill_depth_mm", df = penguins)

In [None]:
# Same two predictors, but the model is fitted on the polynomial terms given in
# the formula string.
plot_fit(
  v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins,
  fmla = "species~I(bill_length_mm^5)+I(bill_depth_mm^3)+I(bill_depth_mm^2)"
)

# QDA from a package


In [None]:
# Plot the class regions predicted by qda() for two predictors.
#
# Fits qda() on `species` and the two chosen columns, predicts the class on an
# N x N grid spanning the observed range of both columns, and draws the
# predicted classes as tiles with the training points on top.
#
# Args:
#   v1, v2:  column names (strings) shown on the x and y axis.
#   df:      data frame containing `species`, v1 and v2.
#   N:       number of grid points per axis.
#   scaleit: if TRUE, centre each predictor and divide it by its sd before
#            fitting.
#   fmla:    model formula given as a string; evaluated on the selected columns.
# Returns a ggplot object.
plot_fit_qda <- function(v1, v2, df = penguins, N = floor(sqrt(10000)), scaleit = TRUE, fmla = "species~.") {
  # Drop rows with a missing value in any used column up front: a single NA
  # would otherwise turn mean()/sd() and min()/max() below into NA.
  train_df <- df %>%
    dplyr::select(all_of(c("species", v1, v2))) %>%
    tidyr::drop_na()
  if (scaleit) {
    train_df <- train_df %>%
      mutate(across(-species, ~ (.x - mean(.x)) / sd(.x)))
  }

  mod <- qda(formula = as.formula(fmla), data = train_df)

  # One evenly spaced axis per predictor; the list keeps the column names, so
  # the expanded grid is already named v1 / v2.
  grid_axes <- map(
    train_df %>% dplyr::select(-species),
    ~ seq(min(.x), max(.x), length.out = N)
  )
  combinations <- expand_grid(!!!grid_axes)
  combinations$species <- predict(mod, newdata = combinations)$class

  ggplot(
    data = combinations,
    mapping = aes(x = !!sym(v1), y = !!sym(v2), fill = species, shape = species)
  ) +
    geom_tile() +
    geom_point(data = train_df, size = 5) +
    coord_fixed()
}

In [None]:
# Use a 10 x 10 figure for the region plots.
options(repr.plot.width = 10, repr.plot.height = 10, repr.plot.res = 100)
# QDA regions for bill length against bill depth.
plot_fit_qda(v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins)

In [None]:
# Same plot with flipper length on the x axis.
plot_fit_qda(v1 = "flipper_length_mm", v2 = "bill_depth_mm", df = penguins)

In [None]:
# Same two predictors, but the model is fitted on the polynomial terms given in
# the formula string.
plot_fit_qda(
  v1 = "bill_length_mm", v2 = "bill_depth_mm", df = penguins,
  fmla = "species~I(bill_length_mm^5)+I(bill_depth_mm^3)+I(bill_depth_mm^2)"
)