###
# code 6 - Klassifikation
# 240521
###

if (!require("pacman")) install.packages("pacman"); library(pacman)
p_load(data.table)
p_load(ggplot2)
p_load(rpart)

# Datensatz laden und konvertieren ---
daten = fread("slides/slides7/autokauf.csv")
daten[, gekauft := factor(gekauft, levels = c(0, 1))]
daten[, geschlecht := NULL]
daten[, nutzer_id := NULL]

# Datensatz in Trainings- und Testdatensatz aufteilen ----
set.seed(1234)
daten[, split := sample(x = c(T,F), size = .N, replace = T, prob = c(0.75, 0.25))]
training_set <- daten[split == TRUE, .(alter, gehalt, gekauft)]
test_set <- daten[split == FALSE, .(alter, gehalt, gekauft)]

# Entscheidungsbaum-Modell trainieren ----
classifier <- rpart(gekauft ~ .,
                    data = training_set)

# Vorhersagen auf dem Testdatensatz ----
training_set[, prediction := predict(classifier,
                                     newdata = training_set,
                                     type = "class")]

## Konfusionsmatrix erstellen ----
training_set[, table(gekauft, prediction)]

## Visualisierung der Ergebnisse auf dem Trainingsdatensatz ----

# Vorhersagen für das Raster
training_grid <- CJ(alter = seq(min(training_set$alter) - 1,
                                max(training_set$alter) + 1, by = 1),
                    gehalt = seq(min(training_set$gehalt) - 1,
                                 max(training_set$gehalt) + 1, by = 1000))
training_grid[, prediction := predict(classifier,
                                      newdata = training_grid,
                                      type = "class")]

# Raster und Datenpunkte plotten
ggplot() +
  geom_tile(data = training_grid, aes(x = alter, y = gehalt, fill = prediction), alpha = 0.5) +
  geom_point(data = training_set, aes(x = alter, y = gehalt, color = gekauft), size = 2) +
  labs(title = "Training set", x = "Alter", y = "Gehalt") +
  scale_fill_manual(values = c("tomato", "springgreen3")) +
  scale_color_manual(values = c("red3", "green4")) +
  theme_minimal()
# ggsave("slides/slides7/decision_tree_classification_training_set.png", width = 15, height = 12, units = "cm")

# Vorhersagen auf dem Testdatensatz ----
test_set[, prediction := predict(classifier,
                                 newdata = test_set,
                                 type = "class")]

## Konfusionsmatrix erstellen ----
test_set[, table(gekauft, prediction)]

## Visualisierung der Ergebnisse auf dem Testdatensatz ----

# Vorhersagen für das Raster
test_grid <- CJ(alter = seq(min(test_set$alter) - 1,
                            max(test_set$alter) + 1, by = 1),
                gehalt = seq(min(test_set$gehalt) - 1,
                             max(test_set$gehalt) + 1, by = 1000))
test_grid[, prediction := predict(classifier,
                                  newdata = test_grid,
                                  type = "class")]

# Raster und Datenpunkte plotten
ggplot() +
  geom_tile(data = test_grid, aes(x = alter, y = gehalt, fill = prediction), alpha = 0.5) +
  geom_point(data = test_set, aes(x = alter, y = gehalt, color = gekauft), size = 2) +
  labs(title = "Decision Tree Classification (Test set)", x = "Age", y = "Estimated Salary") +
  scale_fill_manual(values = c("tomato", "springgreen3")) +
  scale_color_manual(values = c("red3", "green4")) +
  theme_minimal()

# Visualisierung des Entscheidungsbaums ----
plot(classifier)
text(classifier)


## K nearest neighbors ----

# Laden der benötigten Bibliotheken
library(class)

# Beispiel-Datensatz laden
data(iris)

# Vorbereitung der Daten
train_indices <- sample(1:nrow(iris), size = 0.7 * nrow(iris))
train_data <- iris[train_indices, ]
test_data <- iris[-train_indices, ]

# k-NN Modell trainieren und Vorhersagen treffen
knn_model <- knn(train = train_data[, -5], test = test_data[, -5], cl = train_data$Species, k = 3)

# Predictions
table(test_data$Species, knn_model)

# Laden der benötigten Bibliotheken
library(e1071)

# Beispiel-Datensatz laden
data(iris)

# Naive Bayes Modell trainieren
nb_model <- naiveBayes(Species ~ ., data = iris)

# Vorhersagen treffen
predictions <- predict(nb_model, iris)

# Evaluation
table(iris$Species, predictions)
