---
jupytext:
  cell_metadata_filter: -all
  formats: py:percent,md:myst,ipynb
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.13.8
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

# Clustering {#clustering}

```{r setup-clustering, echo = FALSE, message = FALSE, warning = FALSE}
library(knitr)
library(RColorBrewer)
library(gridExtra)
library(cowplot)
library(broom)
library(egg) # ggarrange


#center breaks latex here
knitr::opts_chunk$set(warning = FALSE, fig.align = "default") 

# set the colors in the graphs, 
# some graphs with the code shown to students are hard coded 
cbbPalette <- c(brewer.pal(9, "Paired"))
cbpalette <- c("darkorange3", "dodgerblue3", "goldenrod1")

theme_update(axis.title = element_text(size = 12)) # modify axis label size in plots 
```

## Overview 
As part of exploratory data analysis, it is often helpful to see if there are
meaningful subgroups (or *clusters*) in the data. 
This grouping can be used for many purposes, 
such as generating new questions or improving predictive analyses. 
This chapter provides an introduction to clustering 
using the K-means algorithm,
including techniques to choose the number of clusters.

## Chapter learning objectives 
By the end of the chapter, readers will be able to do the following:

* Describe a case where clustering is appropriate, 
and what insight it might extract from the data.
* Explain the K-means clustering algorithm.
* Interpret the output of a K-means analysis.
* Differentiate between clustering and classification.
* Identify when it is necessary to scale variables before clustering, 
and do this using R.
* Perform K-means clustering in R using `kmeans`.
* Use the elbow method to choose the number of clusters for K-means.
* Visualize the output of K-means clustering in R using colored scatter plots.
* Describe the advantages, 
limitations and assumptions of the K-means clustering algorithm.

## Clustering
Clustering \index{clustering} is a data analysis task 
that involves separating a data set into subgroups of related data. 
For example, we might use clustering to separate a
data set of documents into groups that correspond to topics, a data set of
human genetic information into groups that correspond to ancestral
subpopulations, or a data set of online customers into groups that correspond
to purchasing behaviors.  Once the data are separated, we can, for example,
use the subgroups to generate new questions about the data and follow up with a
predictive modeling exercise. In this book, clustering will be used only for
exploratory analysis, i.e., uncovering patterns in the data.  

Note that clustering is a fundamentally different kind of task 
than classification or regression. 
In particular, both classification and regression are *supervised tasks* 
\index{classification}\index{regression}\index{supervised} 
where there is a *response variable* (a category label or value), 
and we have examples of past data with labels/values 
that help us predict those of future data. 
By contrast, clustering is an *unsupervised task*, 
\index{unsupervised} as we are trying to understand 
and examine the structure of data without any response variable labels 
or values to help us. 
This approach has both advantages and disadvantages. 
Clustering requires no additional annotation or input on the data. 
For example, it would be nearly impossible to annotate 
all the articles on Wikipedia with human-made topic labels. 
However, we can still cluster the articles without this information 
to find groupings corresponding to topics automatically. 

Given that there is no response variable, it is not as easy to evaluate
the "quality" of a clustering.  With classification, we can use a test data set
to assess prediction performance. In clustering, there is not a single good
choice for evaluation. In this book, we will use visualization to ascertain the
quality of a clustering, and leave rigorous evaluation for more advanced
courses.  

As in the case of classification, 
there are many possible methods that we could use to cluster our observations 
to look for subgroups. 
In this book, we will focus on the widely used K-means \index{K-means} algorithm [@kmeans]. 
In your future studies, you might encounter hierarchical clustering,
principal component analysis, multidimensional scaling, and more; 
see the additional resources section at the end of this chapter 
for where to begin learning more about these other methods.

\newpage

> **Note:** There are also so-called *semisupervised* tasks, \index{semisupervised} 
> where only some of the data come with response variable labels/values, 
> but the vast majority don't. 
> The goal is to try to uncover underlying structure in the data 
> that allows one to guess the missing labels. 
> This sort of task is beneficial, for example, 
> when one has an unlabeled data set that is too large to manually label, 
> but one is willing to provide a few informative example labels as a "seed" 
> to guess the labels for all the data.

**An illustrative example** 

Here we will present an illustrative example using a data set \index{Palmer penguins} from
[the `palmerpenguins` R package](https://allisonhorst.github.io/palmerpenguins/) [@palmerpenguins]. This 
data set was collected by Dr. Kristen Gorman and
the Palmer Station, Antarctica Long Term Ecological Research Site, and includes
measurements for adult penguins found near there [@penguinpaper]. We have
modified the data set for use in this chapter. Here we will focus on using two
variables&mdash;penguin bill and flipper length, both in millimeters&mdash;to determine whether 
there are distinct types of penguins in our data.
Understanding this might help us with species discovery and classification in a data-driven
way.

```{r 09-penguins, echo = FALSE, message = FALSE, warning = FALSE, fig.cap = "Gentoo penguin.", out.width="60%", fig.align = "center", fig.retina = 2}
# image source: https://commons.wikimedia.org/wiki/File:Gentoo_Penguin._(8671680772).jpg
knitr::include_graphics("img/gentoo.jpg")
```

To learn about K-means clustering, 
we will work with `penguin_data` in this chapter.
`penguin_data` is a subset of 18 observations of the original data, 
which has already been standardized 
(remember from Chapter \@ref(classification) 
that scaling is part of the standardization process). 
We will discuss scaling for K-means in more detail later in this chapter. 
\index{mutate}\index{read function!read\_csv} 
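
As a rough sketch of what standardization involves, the chunk below centers and scales two columns using base R's `scale` function. The data frame `toy_raw` and its values are made up for illustration; they are not the penguin measurements, and only the column names mirror the ones used in this chapter.

```{r}
# Hypothetical raw measurements in millimeters (made-up values, not the penguin data)
toy_raw <- data.frame(
  flipper_length_mm = c(180, 190, 200, 210, 220),
  bill_length_mm = c(38, 40, 45, 47, 50)
)

# Standardize each column: subtract its mean (centering),
# then divide by its standard deviation (scaling)
toy_standardized <- as.data.frame(scale(toy_raw))

# Each standardized column should now have mean 0 and standard deviation 1,
# up to floating-point rounding
colMeans(toy_standardized)
sapply(toy_standardized, sd)
```

After this step, both variables are on a comparable scale, so neither one dominates the distance calculations purely because of its units.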

Before we get started, we will load the `tidyverse` metapackage 
as well as set a random seed.
This will ensure we have access to the functions we need 
and that our analysis will be reproducible.
As we will learn in more detail later in the chapter, 
setting the seed here is important 
because the K-means clustering algorithm uses random numbers.

```{r 10-toy-example-data, echo = FALSE, message = FALSE, warning = FALSE}
library(tidyverse)

data <- read_csv("data/toy_penguins.csv") |>
  mutate(cluster = as_factor(cluster)) |>
  mutate(flipper_length_standardized = as.double(scale(flipper_length_mm)), 
         bill_length_standardized = as.double(scale(bill_length_mm)))

penguin_data <- data |> select(flipper_length_standardized, 
bill_length_standardized)

write_csv(penguin_data, "data/penguins_standardized.csv")
```

\index{seed!set.seed}

```{r 10-clustering, warning = FALSE, message = FALSE}
library(tidyverse)
set.seed(1)
```

Now we can load and preview the data.

```{r, message = FALSE, warning = FALSE}
penguin_data <- read_csv("data/penguins_standardized.csv")
penguin_data
```

Next, we can create a scatter plot using this data set 
to see if we can detect subtypes or groups in our data set.

\newpage

```{r 10-toy-example-plot, warning = FALSE, fig.height = 3.25, fig.width = 3.5, fig.align = "center", fig.pos = "H", out.extra="", fig.cap = "Scatter plot of standardized bill length versus standardized flipper length."}
ggplot(penguin_data, aes(x = flipper_length_standardized, 
                 y = bill_length_standardized)) +
  geom_point() +
  xlab("Flipper Length (standardized)") +
  ylab("Bill Length (standardized)") + 
  theme(text = element_text(size = 12))
```

Based \index{ggplot}\index{ggplot!geom\_point} on the visualization 
in Figure \@ref(fig:10-toy-example-plot), 
we might suspect there are a few subtypes of penguins within our data set.
We can see roughly 3 groups of observations in Figure \@ref(fig:10-toy-example-plot),
including:

1. a small flipper and bill length group,
2. a small flipper length, but large bill length group, and
3. a large flipper and bill length group.

Data visualization is a great tool to give us a rough sense of such patterns
when we have a small number of variables. 
But if we are to group data&mdash;and select the number of groups&mdash;as part of 
a reproducible analysis, we need something a bit more automated.
Additionally, finding groups via visualization becomes more difficult 
as we increase the number of variables we consider when clustering.
The way to rigorously separate the data into groups 
is to use a clustering algorithm.
In this chapter, we will focus on the *K-means* algorithm, 
\index{K-means} a widely used and often very effective clustering method, 
combined with the *elbow method* \index{elbow method} 
for selecting the number of clusters. 
This procedure will separate the data into groups;
Figure \@ref(fig:10-toy-example-clustering) shows these groups
denoted by colored scatter points.

```{r 10-toy-example-clustering, echo = FALSE, warning = FALSE, fig.height = 3.25, fig.width = 4.25, fig.align = "center", fig.cap = "Scatter plot of standardized bill length versus standardized flipper length with colored groups."}
ggplot(data, aes(y = bill_length_standardized, 
                 x = flipper_length_standardized, color = cluster)) +
  geom_point() +
  xlab("Flipper Length (standardized)") +
  ylab("Bill Length (standardized)") + 
  scale_color_manual(values= c("darkorange3", "dodgerblue3", "goldenrod1"))
```

What are the labels for these groups? Unfortunately, we don't have any. K-means,
like almost all clustering algorithms, just outputs meaningless "cluster labels"
that are typically whole numbers: 1, 2, 3, etc. But in a simple case like this,
where we can easily visualize the clusters on a scatter plot, we can give
human-made labels to the groups using their positions on
the plot:

- small flipper length and small bill length (<font color="#D55E00">orange cluster</font>), 
- small flipper length and large bill length (<font color="#0072B2">blue cluster</font>), and
- large flipper length and large bill length (<font color="#F0E442">yellow cluster</font>).

Once we have made these determinations, we can use them to inform our species
classifications or ask further questions about our data. For example, we might
be interested in understanding the relationship between flipper length and bill
length, and that relationship may differ depending on the type of penguin we
have. 

## K-means 

### Measuring cluster quality

```{r 10-toy-example-clus1, echo = FALSE, message = FALSE, warning = FALSE}
library(tidyverse)

clus1 <- filter(data, cluster == 2) |>
  select(bill_length_standardized, flipper_length_standardized)
```

The K-means algorithm is a procedure that groups data into K clusters.
It starts with an initial clustering of the data, and then iteratively
improves it by making adjustments to the assignment of data
to clusters until it cannot improve any further. But how do we measure
the "quality" of a clustering, and what does it mean to improve it? 
In K-means clustering, we measure the quality of a cluster by its
\index{within-cluster sum-of-squared-distances|see{WSSD}}\index{WSSD}
*within-cluster sum-of-squared-distances* (WSSD). Computing this involves two steps.
First, we find the cluster centers by computing the mean of each variable 
over data points in the cluster. For example, suppose we have a 
cluster containing four observations, and we are using two variables, $x$ and $y$, to cluster the data.
Then we would compute the coordinates, $\mu_x$ and $\mu_y$, of the cluster center via

$$\mu_x = \frac{1}{4}(x_1+x_2+x_3+x_4) \quad \mu_y = \frac{1}{4}(y_1+y_2+y_3+y_4).$$

In the first cluster from the example, there are `r nrow(clus1)` data points. These are shown with their cluster center 
(`r paste("flipper_length_standardized =", round(mean(clus1$flipper_length_standardized),2))` and `r paste("bill_length_standardized =", round(mean(clus1$bill_length_standardized),2))`) highlighted 
in Figure \@ref(fig:10-toy-example-clus1-center).

(ref:10-toy-example-clus1-center) Cluster 1 from the `penguin_data` data set example. Observations are in blue, with the cluster center highlighted in red.

```{r 10-toy-example-clus1-center, echo = FALSE, warning = FALSE, fig.height = 3.25, fig.width = 3.5, fig.align = "center", fig.cap = "(ref:10-toy-example-clus1-center)"}
base <- ggplot(clus1) +
  geom_point(aes(y = bill_length_standardized, x = flipper_length_standardized),  
  col = "dodgerblue3") +
  labs(x = "Flipper Length (standardized)", y = "Bill Length (standardized)") +
  xlim(c(
    min(clus1$flipper_length_standardized) - 0.25 * 
      sd(clus1$flipper_length_standardized),
    max(clus1$flipper_length_standardized) + 0.25 * 
      sd(clus1$flipper_length_standardized)
  )) +
  ylim(c(
    min(clus1$bill_length_standardized) - 0.25 * 
      sd(clus1$bill_length_standardized),
    max(clus1$bill_length_standardized) + 0.25 * 
      sd(clus1$bill_length_standardized)
  )) +
  geom_point(aes(y = mean(bill_length_standardized), 
                 x = mean(flipper_length_standardized)), 
             color = "#F8766D", 
             size = 5) +
  theme(legend.position = "none")

base
```
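
The center computation described above can be sketched on a small cluster of four observations. The data frame `center_points` and its values are hypothetical, chosen so that the arithmetic is easy to check by hand.

```{r}
# Hypothetical four-observation cluster (made-up values)
center_points <- data.frame(
  x = c(1, 2, 3, 2),
  y = c(4, 6, 5, 5)
)

# The cluster center is the mean of each variable over the points in the cluster:
# mu_x = (1 + 2 + 3 + 2) / 4 = 2 and mu_y = (4 + 6 + 5 + 5) / 4 = 5
toy_center <- colMeans(center_points)
toy_center
```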

The second step in computing the WSSD is to add up the squared distance 
\index{distance!K-means} between each point in the cluster 
and the cluster center.
We use the straight-line (Euclidean) distance formula 
that we learned about in Chapter \@ref(classification).
In the `r nrow(clus1)`-observation cluster example above, 
we would compute the WSSD $S^2$ via

\begin{align*}
S^2 = \left((x_1 - \mu_x)^2 + (y_1 - \mu_y)^2\right) + \left((x_2 - \mu_x)^2 + (y_2 - \mu_y)^2\right) + \\ \left((x_3 - \mu_x)^2 + (y_3 - \mu_y)^2\right)  +  \left((x_4 - \mu_x)^2 + (y_4 - \mu_y)^2\right).
\end{align*}

These distances are denoted by lines in Figure \@ref(fig:10-toy-example-clus1-dists) for the first cluster of the penguin data example. 

(ref:10-toy-example-clus1-dists) Cluster 1 from the `penguin_data` data set example. Observations are in blue, with the cluster center highlighted in red. The distances from the observations to the cluster center are represented as black lines.

```{r 10-toy-example-clus1-dists, echo = FALSE, warning = FALSE, fig.height = 3.25, fig.width = 3.5, fig.align = "center", fig.cap = "(ref:10-toy-example-clus1-dists)"}
base <- ggplot(clus1) +
  geom_point(aes(y = bill_length_standardized, 
                 x = flipper_length_standardized),
             col = "dodgerblue3") +
  labs(x = "Flipper Length (standardized)", y = "Bill Length (standardized)") +
  theme(legend.position = "none") 

mn <- clus1 |> 
  summarize(flipper_length_standardized = mean(flipper_length_standardized), 
            bill_length_standardized = mean(bill_length_standardized))
for (i in 1:nrow(clus1)) {
  base <- base + geom_segment(
    x = unlist(mn[1, "flipper_length_standardized"]), 
    y = unlist(mn[1, "bill_length_standardized"]),
    xend = unlist(clus1[i, "flipper_length_standardized"]), 
    yend = unlist(clus1[i, "bill_length_standardized"])
  )
}
base <- base + 
  geom_point(aes(y = mean(bill_length_standardized), 
                 x = mean(flipper_length_standardized)), 
             color = "#F8766D", 
             size = 5)
base
```

The larger the value of $S^2$, the more spread out the cluster is, since large $S^2$ means that points are far from the cluster center.
Note, however, that "large" is relative to *both* the scale of the variables for clustering *and* the number of points in the cluster. A cluster where points are very close to the center might still have a large $S^2$ if there are many data points in the cluster.
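
To make both points concrete, here is a minimal sketch that computes the WSSD of one small cluster and then of the same cluster with every point duplicated. The data values and the helper function `single_cluster_wssd` are hypothetical and exist only for this illustration.

```{r}
# Hypothetical four-observation cluster (made-up values)
wssd_points <- data.frame(x = c(1, 2, 3, 2), y = c(4, 6, 5, 5))

# Helper for this sketch only: WSSD of a single cluster with columns x and y
single_cluster_wssd <- function(points) {
  mu_x <- mean(points$x)
  mu_y <- mean(points$y)
  # Sum of squared straight-line distances from each point to the center
  sum((points$x - mu_x)^2 + (points$y - mu_y)^2)
}

wssd_once <- single_cluster_wssd(wssd_points)

# Stacking a second copy of every point leaves the center and the spread
# unchanged, but doubles the number of terms in the sum
wssd_twice <- single_cluster_wssd(rbind(wssd_points, wssd_points))

c(four_points = wssd_once, eight_points = wssd_twice)
```

The second value is twice the first even though the points are no further from the center, which is why $S^2$ should be compared with the cluster size in mind.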

After we have calculated the WSSD for all the clusters, 
we sum them together to get the *total WSSD*.
For our example, 
this means adding up all the squared distances for the 18 observations.
These distances are denoted by black lines in
Figure \@ref(fig:10-toy-example-all-clus-dists).

(ref:10-toy-example-all-clus-dists) All clusters from the `penguin_data` data set example. Observations are in orange, blue, and yellow with the cluster centers highlighted in red. The distances from the observations to each of the respective cluster centers are represented as black lines.

```{r 10-toy-example-all-clus-dists, echo = FALSE, warning = FALSE, fig.height = 3.25, fig.width = 4.25, fig.align = "center", fig.cap = "(ref:10-toy-example-all-clus-dists)"}


all_clusters_base <- data |>
  ggplot(aes(y = bill_length_standardized,
             x = flipper_length_standardized,
             color = cluster)) +
  geom_point() +
  xlab("Flipper Length (standardized)") +
  ylab("Bill Length (standardized)") + 
  scale_color_manual(values= c("darkorange3", 
                               "dodgerblue3", 
                               "goldenrod1"))
cluster_centers <- tibble(x = c(0, 0, 0),
                          y = c(0, 0, 0))

for (cluster_number in 1:3) {
  
  clus <- filter(data, cluster == cluster_number) |>
    select(bill_length_standardized, flipper_length_standardized)
  
  mn <- clus |> 
    summarize(flipper_length_standardized = mean(flipper_length_standardized),
              bill_length_standardized = mean(bill_length_standardized))
  
  for (i in 1:nrow(clus)) {
    all_clusters_base <- all_clusters_base + 
      geom_segment(x = unlist(mn[1, "flipper_length_standardized"]), 
                   y = unlist(mn[1, "bill_length_standardized"]),
      xend = unlist(clus[i, "flipper_length_standardized"]), 
      yend = unlist(clus[i, "bill_length_standardized"]),
      color = "black"
    )
  }
  
  #all_clusters_base <- all_clusters_base + 
  #  geom_point(aes(y = mean(clus$bill_length_standardized), 
  #                 x = mean(clus$flipper_length_standardized)), 
  #                 color = "#F8766D", size = 3)
  #print(mean(clus$bill_length_standardized))
  #print(mean(clus$flipper_length_standardized))
  cluster_centers[cluster_number, 1] <- mean(clus$flipper_length_standardized)
  cluster_centers[cluster_number, 2] <- mean(clus$bill_length_standardized)
}

all_clusters_base <- all_clusters_base + 
  geom_point(aes(y = cluster_centers$y[1], 
                 x = cluster_centers$x[1]), 
             color = "#F8766D", size = 3) +
  geom_point(aes(y = cluster_centers$y[2], 
                 x = cluster_centers$x[2]), 
             color = "#F8766D", size = 3) +
  geom_point(aes(y = cluster_centers$y[3], 
                 x = cluster_centers$x[3]), 
             color = "#F8766D", size = 3)

all_clusters_base
```
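
The total WSSD calculation can be sketched by computing the WSSD of each cluster separately and then adding the results. The data frame `toy_clustered`, its cluster labels, and the helper `wssd_of` are assumptions made for this example and are not part of the penguin analysis.

```{r}
# Hypothetical data with cluster labels already assigned (made-up values)
toy_clustered <- data.frame(
  x = c(0, 0, 4, 4, 8, 8),
  y = c(0, 2, 0, 2, 0, 2),
  cluster = c(1, 1, 2, 2, 3, 3)
)

# Helper for this sketch only: WSSD of one cluster
wssd_of <- function(points) {
  center <- colMeans(points)
  # Subtract the center from every row, square, and add everything up
  sum(sweep(as.matrix(points), 2, center)^2)
}

# Split the coordinates by cluster label and compute one WSSD per cluster
wssd_by_cluster <- sapply(
  split(toy_clustered[, c("x", "y")], toy_clustered$cluster),
  wssd_of
)

# The total WSSD is the sum over all clusters
total_wssd <- sum(wssd_by_cluster)
wssd_by_cluster
total_wssd
```

Here each cluster holds two points that sit one unit above and below its center, so each cluster contributes 2 and the total WSSD is 6.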

\newpage

### The clustering algorithm

We begin the K-means \index{K-means!algorithm} algorithm by picking K, 
and randomly assigning a roughly equal number of observations 
to each of the K clusters.
An example random initialization is shown in Figure \@ref(fig:10-toy-kmeans-init).
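
One possible way to produce such a starting assignment is sketched below: build a label vector with equal counts per cluster and shuffle it. The seed value and the names `sketch_n`, `sketch_k`, and `init_labels` are arbitrary choices for this illustration, and the figure in this section may have been generated differently.

```{r}
set.seed(2023)  # arbitrary seed so the sketch is reproducible

sketch_n <- 18  # number of observations, as in the chapter's example
sketch_k <- 3   # number of clusters, K

# Build a label vector with (as near as possible) equal counts per cluster,
# then shuffle it so that the assignment of labels to observations is random
init_labels <- sample(rep(1:sketch_k, length.out = sketch_n))

# Count how many observations start in each cluster
table(init_labels)
```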

```{r 10-toy-kmeans-init, echo = FALSE, message = FALSE, warning = FALSE, fig.height = 3.5, fig.width = 3.75, fig.align = "center", fig.cap = "Random initialization of labels."}
set.seed(14)
penguin_data["label"] <- factor(sample(1:3, nrow(penguin_data), replace = TRUE))

plt_lbl <- ggplot(penguin_data, aes(y = bill_length_standardized, 
                                    x = flipper_length_standardized, 
                                    color = label)) +
  geom_point(size = 2) +
  xlab("Flipper Length (standardized)") +
  ylab("Bill Length (standardized)") +
  theme(legend.position = "none") + 
  scale_color_manual(values= cbpalette)

plt_lbl
```

Then K-means consists of two major steps that attempt to minimize the
sum of WSSDs over all the clusters, i.e., the \index{WSSD!total} *total WSSD*:

1. **Center update:** Compute the center of each cluster.
2. **Label update:** Reassign each data point to the cluster with the nearest center.

These two steps are repeated until the cluster assignments no longer change.
We show what the first four iterations of K-means would look like in  
Figure \@ref(fig:10-toy-kmeans-iter). 
There each pair of plots corresponds to an iteration:
the first plot in the pair depicts the center update, 
and the second depicts the reassignment of data to clusters.

(ref:10-toy-kmeans-iter) First four iterations of K-means clustering on the `penguin_data` example data set. Each pair of plots corresponds to an iteration. Within the pair, the first plot depicts the center update, and the second plot depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black.

```{r 10-toy-kmeans-iter, echo = FALSE, warning = FALSE, fig.height = 4.5, fig.width = 8, fig.align = "center", fig.cap = "(ref:10-toy-kmeans-iter)"}
list_plot_cntrs <- vector(mode = "list", length = 4)
list_plot_lbls <- vector(mode = "list", length = 4)

for (i in 1:4) {
  # compute centers
  centers <- penguin_data |>
    group_by(label) |>
    summarize(across(everything(), mean))
  nclus <- nrow(centers)
  # replot with centers
  plt_ctr <- ggplot(penguin_data, aes(y = bill_length_standardized, 
                                      x = flipper_length_standardized, 
                                      color = label)) +
    geom_point(size = 2) +
    xlab("Flipper Length\n(standardized)") +
    ylab("Bill Length\n(standardized)") +
    theme(legend.position = "none") +
    scale_color_manual(values= cbpalette) + 
    geom_point(data = centers, 
               aes(y = bill_length_standardized, 
                                   x = flipper_length_standardized, 
                                   fill = label), 
               size = 4, 
               shape = 21, 
               stroke = 1, 
               color = "black", 
               fill = cbpalette) +
    annotate("text", x = -0.5, y = 1.5, label = paste0("Iteration ", i), size = 5)+ 
    theme(text = element_text(size = 14), axis.title=element_text(size=14)) 
  
  if (i == 1 | i == 2) {
    plt_ctr <- plt_ctr +
      ggtitle("Center Update")
  }
  
  # reassign labels
  dists <- rbind(centers, penguin_data) |>
    select("flipper_length_standardized", "bill_length_standardized") |>
    dist() |>
    as.matrix()
  dists <- as_tibble(dists[-(1:nclus), 1:nclus])
  penguin_data <- penguin_data |> 
    mutate(label = apply(dists, 1, function(x) names(x)[which.min(x)]))

  plt_lbl <- ggplot(penguin_data, 
                    aes(y = bill_length_standardized, 
                        x = flipper_length_standardized, 
                        color = label)) +
    geom_point(size = 2) +
    xlab("Flipper Length\n(standardized)") +
    ylab("Bill Length\n(standardized)") +
    theme(legend.position = "none") +
    scale_color_manual(values= cbpalette) +
    geom_point(data = centers, 
               aes(y = bill_length_standardized, 
                   x = flipper_length_standardized, fill = label), 
               size = 4, 
               shape = 21, 
               stroke = 1, 
               color = "black", 
               fill = cbpalette) +
    annotate("text", x = -0.5, y = 1.5, label = paste0("Iteration ", i), size = 5) + 
    theme(text = element_text(size = 14), axis.title=element_text(size=14)) 

  if (i == 1 | i ==2) {
    plt_lbl <- plt_lbl +
      ggtitle("Label Update")
  }
  
  list_plot_cntrs[[i]] <- plt_ctr
  list_plot_lbls[[i]] <- plt_lbl
}

iter_plot_list <- c(list_plot_cntrs[1], list_plot_lbls[1],
                    list_plot_cntrs[2], list_plot_lbls[2],
                    list_plot_cntrs[3], list_plot_lbls[3],
                    list_plot_cntrs[4], list_plot_lbls[4])

ggarrange(iter_plot_list[[1]] +
               theme(axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(), 
                     plot.margin = margin(r = 2, b = 2)), 
          iter_plot_list[[2]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(r = 2, l = 2, b = 2) ), 
          iter_plot_list[[3]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(r = 2, l = 2, b = 2)  ),
          iter_plot_list[[4]] + 
            theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(), 
                     axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(l = 2, b = 2)  ),
          iter_plot_list[[5]] +
               theme(plot.margin = margin(r = 2, t = 2)), 
          iter_plot_list[[6]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     plot.margin = margin(r = 2, l = 2, t = 2) ), 
          iter_plot_list[[7]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     plot.margin = margin(r = 2, l = 2, t = 2)  ),
          iter_plot_list[[8]] + theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(), 
                     plot.margin = margin(l = 2, t = 2)  ),
           nrow = 2)
```

Note that at this point, we can terminate the algorithm since none of the assignments changed
in the fourth iteration; both the centers and labels will remain the same from this point onward.

> **Note:** Is K-means *guaranteed* to stop at some point, or could it iterate forever? As it turns out,
> thankfully, the answer is that K-means \index{K-means!termination} is guaranteed to stop after *some* number of iterations. For the interested reader, the
> logic for this has three steps: (1) neither the label update nor the center update ever increases the total WSSD,
> (2) the total WSSD is always greater than or equal to 0, and (3) there are only a finite number of possible
> ways to assign the data to clusters. So at some point, the total WSSD must stop decreasing, which means none of the assignments
> are changing, and the algorithm terminates.
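
The two update steps and the stopping rule can be sketched in a few lines of base R. This is a minimal illustration, not a replacement for R's `kmeans` function: the data are made up, the starting labels are fixed rather than random so the result is the same on every run, and the sketch assumes that no cluster ever becomes empty. It also records the total WSSD after each center update so that we can check it never goes up.

```{r}
# Hypothetical standardized data: three groups of three points (made-up values)
loop_points <- matrix(
  c(-1.0, -1.0,
    -1.2, -0.8,
    -0.8, -1.2,
    -1.0,  1.0,
    -1.2,  1.2,
    -0.8,  0.8,
     1.0,  1.0,
     1.2,  1.2,
     0.8,  0.8),
  ncol = 2, byrow = TRUE
)

# Fixed starting labels (an assumption of this sketch; K-means starts randomly)
loop_labels <- c(1, 1, 2, 2, 2, 3, 3, 3, 1)
wssd_trace <- c()

repeat {
  # Center update: mean of each variable within each cluster (one row per cluster)
  loop_centers <- t(sapply(1:3, function(k) {
    colMeans(loop_points[loop_labels == k, , drop = FALSE])
  }))

  # Record the total WSSD for the current labels and centers
  wssd_trace <- c(wssd_trace,
                  sum((loop_points - loop_centers[loop_labels, ])^2))

  # Label update: squared distance from every point to every center,
  # then pick the nearest center for each point
  sq_dists <- sapply(1:3, function(k) {
    rowSums(sweep(loop_points, 2, loop_centers[k, ])^2)
  })
  new_labels <- apply(sq_dists, 1, which.min)

  # Stop once no assignment changes
  if (all(new_labels == loop_labels)) break
  loop_labels <- new_labels
}

loop_labels
wssd_trace
```

On this toy data, the loop stops in its second pass with each group of three points in its own cluster, and the recorded total WSSD values do not increase from one pass to the next.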

What kind of data is suitable for K-means clustering? 
In the simplest version of K-means clustering that we have presented here,
the straight-line distance is used to measure the 
distance between observations and cluster centers. 
This means that only quantitative data should be used with this algorithm.
There are variants on the K-means algorithm, 
as well as other clustering algorithms entirely, 
that use other distance metrics 
to allow for non-quantitative data to be clustered. 
These, however, are beyond the scope of this book.

### Random restarts

Unlike the classification and regression models we studied in previous chapters, K-means \index{K-means!restart, nstart} can get "stuck" in a bad solution.
For example, Figure \@ref(fig:10-toy-kmeans-bad-init) illustrates an unlucky random initialization by K-means.

```{r 10-toy-kmeans-bad-init, echo = FALSE, warning = FALSE, message = FALSE, fig.height = 3.25, fig.width = 3.75, fig.pos = "H", out.extra="", fig.align = "center", fig.cap = "Random initialization of labels."}
penguin_data <- penguin_data |>
  mutate(label = as_factor(c(3L, 3L, 1L, 1L, 2L, 1L, 2L, 1L, 1L, 
                             1L, 3L, 1L, 2L, 2L, 2L, 3L, 3L, 3L)))

plt_lbl <- ggplot(penguin_data, aes(y = bill_length_standardized, 
                                    x = flipper_length_standardized, 
                                    color = label)) +
  geom_point(size = 2) +
  xlab("Flipper Length (standardized)") +
  ylab("Bill Length (standardized)") +
  scale_color_manual(values= cbpalette) +
  theme(legend.position = "none")

plt_lbl
```
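
To see how a starting point can trap the algorithm, the sketch below runs R's `kmeans` function twice on a small made-up data set, each time with explicitly chosen starting centers instead of a random initialization. The data and both sets of starting centers are hypothetical. We also set `algorithm = "Lloyd"` so that the function alternates the two update steps described in this chapter; R's default algorithm is a different variant and may behave differently on this example.

```{r}
# Hypothetical data: three tight pairs of points (made-up values)
restart_points <- matrix(
  c(-1.0, -1.0,
    -1.0, -1.2,
    -1.0,  1.0,
    -1.0,  1.2,
     1.0,  1.0,
     1.4,  1.0),
  ncol = 2, byrow = TRUE
)

# A favorable start: one center inside each pair
good_start <- matrix(c(-1, -1,
                       -1,  1,
                        1,  1), ncol = 2, byrow = TRUE)

# An unlucky start: one center between two pairs, two centers inside the third
bad_start <- matrix(c(-1.0, 0,
                       1.0, 1,
                       1.4, 1), ncol = 2, byrow = TRUE)

good_fit <- kmeans(restart_points, centers = good_start, algorithm = "Lloyd")
bad_fit <- kmeans(restart_points, centers = bad_start, algorithm = "Lloyd")

# Compare the total WSSD reached from each starting point
c(good = good_fit$tot.withinss, bad = bad_fit$tot.withinss)
```

From the unlucky start, no point is closer to another cluster's center than to its own, so the assignments never change and the algorithm stops with a much larger total WSSD. Comparing the total WSSD across several starting points is one way to tell such solutions apart.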

Figure \@ref(fig:10-toy-kmeans-bad-iter) shows what the iterations of K-means would look like with the unlucky random initialization shown in Figure \@ref(fig:10-toy-kmeans-bad-init).

(ref:10-toy-kmeans-bad-iter) First five iterations of K-means clustering on the `penguin_data` example data set with a poor random initialization. Each pair of plots corresponds to an iteration. Within the pair, the first plot depicts the center update, and the second plot depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black.

```{r 10-toy-kmeans-bad-iter, echo = FALSE, warning = FALSE, message = FALSE, fig.height = 6.75, fig.width = 8, fig.pos = "H", out.extra="", fig.align = "center", fig.cap = "(ref:10-toy-kmeans-bad-iter)"}
list_plot_cntrs <- vector(mode = "list", length = 5)
list_plot_lbls <- vector(mode = "list", length = 5)

for (i in 1:5) {
  # compute centers
  centers <- penguin_data |>
    group_by(label) |>
    summarize(across(everything(), mean))
  nclus <- nrow(centers)
  # replot with centers
  plt_ctr <- ggplot(penguin_data, aes(y = bill_length_standardized, 
                                      x = flipper_length_standardized, 
                                      color = label)) +
    geom_point(size = 2) +
    xlab("Flipper Length\n(standardized)") +
    ylab("Bill Length\n(standardized)") +
    theme(legend.position = "none") +
    scale_color_manual(values= cbpalette) + 
    geom_point(data = centers, aes(y = bill_length_standardized, 
                                   x = flipper_length_standardized, 
                                   fill = label), 
               size = 4, 
               shape = 21, 
               stroke = 1, 
               color = "black", 
               fill = cbpalette) +
    annotate("text", x = -0.5, y = 1.5, label = paste0("Iteration ", i), size = 5) + 
    theme(text = element_text(size = 14), axis.title=element_text(size=14)) 

  if (i == 1 | i == 2) {
    plt_ctr <- plt_ctr +
      ggtitle("Center Update")
  }
  
  # reassign labels
  dists <- rbind(centers, penguin_data) |>
    select("flipper_length_standardized", "bill_length_standardized") |>
    dist() |>
    as.matrix()
  dists <- as_tibble(dists[-(1:nclus), 1:nclus])
  penguin_data <- penguin_data |> 
    mutate(label = apply(dists, 1, function(x) names(x)[which.min(x)]))

  plt_lbl <- ggplot(penguin_data, aes(y = bill_length_standardized, 
                                      x = flipper_length_standardized, 
                                      color = label)) +
    geom_point(size = 2) +
    xlab("Flipper Length\n(standardized)") +
    ylab("Bill Length\n(standardized)") +
    theme(legend.position = "none") +
    scale_color_manual(values= cbpalette) +
    geom_point(data = centers, aes(y = bill_length_standardized, 
                                   x = flipper_length_standardized, 
                                   fill = label), 
               size = 4, 
               shape = 21, 
               stroke = 1, 
               color = "black", 
               fill = cbpalette) +
    annotate("text", x = -0.5, y = 1.5, label = paste0("Iteration ", i), size = 5) + 
    theme(text = element_text(size = 14), axis.title=element_text(size=14)) 

  if (i == 1 || i == 2) {
    plt_lbl <- plt_lbl +
      ggtitle("Label Update")
  }
  
  list_plot_cntrs[[i]] <- plt_ctr
  list_plot_lbls[[i]] <- plt_lbl
}

iter_plot_list <- c(list_plot_cntrs[1], list_plot_lbls[1],
                    list_plot_cntrs[2], list_plot_lbls[2],
                    list_plot_cntrs[3], list_plot_lbls[3],
                    list_plot_cntrs[4], list_plot_lbls[4],
                    list_plot_cntrs[5], list_plot_lbls[5])

ggarrange(iter_plot_list[[1]] +
               theme(axis.text.x = element_blank(),  #remove x axis
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(), 
                     plot.margin = margin(r = 2, b = 2)), # change margins
          iter_plot_list[[2]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(r = 2, l = 2, b = 2) ), 
          iter_plot_list[[3]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(r = 2, l = 2, b = 2)),
          iter_plot_list[[4]] + 
            theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(), 
                     axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(l = 2, b = 2) ),
          iter_plot_list[[5]] +
               theme(axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(r = 2, t = 2, b = 2)),
          iter_plot_list[[6]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     axis.text.x = element_blank(),
                     axis.ticks.x = element_blank(),
                     axis.title.x = element_blank(),
                     plot.margin = margin(r = 2, l = 2, t = 2, b = 2) ), 
          iter_plot_list[[7]] + 
               theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(),
                     plot.margin = margin(r = 2, l = 2, t = 2, b = 2)  ),
          iter_plot_list[[8]] + theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(), 
                     plot.margin = margin(l = 2, t = 2, b = 2)),  
          ggplot() + theme_void(), ggplot() + theme_void(), ggplot() + theme_void(), ggplot() + theme_void(), # adding third row of empty plots to change space between third and fourth row
          iter_plot_list[[9]] + 
               theme(plot.margin = margin(r = 2)),
          iter_plot_list[[10]] + 
            theme(axis.text.y = element_blank(),
                     axis.ticks.y = element_blank(),
                     axis.title.y = element_blank(), 
                     plot.margin = margin(l = 2)  ),
         heights = c(3, 3, -1, 3),
          ncol = 4)
```

This looks like a relatively bad clustering of the data, but K-means cannot improve it.
To work around this problem when clustering data using K-means, we should randomly re-initialize the labels a few times, 
run K-means for each initialization, and pick the clustering that has the lowest final total WSSD.

### Choosing K

In order to cluster data using K-means, 
we also have to pick the number of clusters, K.
But unlike in classification, we have no response variable 
and cannot perform cross-validation with some measure of model prediction error.
Further, if K is chosen too small, then multiple clusters get grouped together;
if K is too large, then clusters get subdivided. 
In both cases, we will potentially miss interesting structure in the data. 
Figure \@ref(fig:10-toy-kmeans-vary-k) illustrates the impact of K 
on K-means clustering of our penguin flipper and bill length data 
by showing the different clusterings for values of K ranging from 1 to 9.

```{r 10-toy-kmeans-vary-k, echo = FALSE, warning = FALSE, fig.height = 6.25, fig.width = 6, fig.pos = "H", out.extra="", fig.cap = "Clustering of the penguin data for K clusters ranging from 1 to 9. Cluster centers are indicated by larger points that are outlined in black."}
set.seed(3)

kclusts <- tibble(k = 1:9) |>
  mutate(
    kclust = map(k, ~ kmeans(penguin_data[-3], .x)),
    tidied = map(kclust, tidy),
    glanced = map(kclust, glance),
    augmented = map(kclust, augment, penguin_data[-3])
  )

clusters <- kclusts |>
  unnest(tidied)

assignments <- kclusts |>
  unnest(augmented)

clusterings <- kclusts |>
  unnest(glanced)

clusters_levels <- c("1 Cluster", 
                     "2 Clusters", 
                     "3 Clusters", 
                     "4 Clusters", 
                     "5 Clusters", 
                     "6 Clusters", 
                     "7 Clusters", 
                     "8 Clusters", 
                     "9 Clusters")

assignments$k <- factor(assignments$k)
levels(assignments$k) <- clusters_levels

clusters$k <- factor(clusters$k)
levels(clusters$k) <- clusters_levels

p1 <- ggplot(assignments, aes(flipper_length_standardized, 
                              bill_length_standardized)) +
  geom_point(aes(color = .cluster, size = I(2))) +
  facet_wrap(~k) +   scale_color_manual(values = cbbPalette) +
  labs(x = "Flipper Length (standardized)", 
       y = "Bill Length (standardized)", 
       color = "Cluster") +
  theme(legend.position = "none") +
  geom_point(data = clusters, 
             aes(fill = cluster), 
             color = "black", 
             size = 4, 
             shape = 21, 
             stroke = 1) + 
  scale_fill_manual(values = cbbPalette) +     
  theme(text = element_text(size = 12), axis.title=element_text(size=12)) 


p1
```

If we set K less than 3, then the clustering merges separate groups of data; this causes a large 
total WSSD, since the cluster center (denoted by a larger point outlined in black) is not close to any of the data in the cluster. On 
the other hand, if we set K greater than 3, the clustering subdivides subgroups of data; this does indeed still 
decrease the total WSSD, but by only a *diminishing amount*. If we plot the total WSSD versus the number of 
clusters, we see that the decrease in total WSSD levels off (or forms an "elbow shape") \index{elbow method} when we reach roughly 
the right number of clusters (Figure \@ref(fig:10-toy-kmeans-elbow)).

```{r 10-toy-kmeans-elbow, echo = FALSE, warning = FALSE, fig.align = 'center', fig.height = 3.25, fig.width = 4.25, fig.pos = "H", out.extra="", fig.cap = "Total WSSD for K clusters ranging from 1 to 9."}
p2 <- ggplot(clusterings, aes(x = k, y = tot.withinss)) +
  geom_point(size = 2) +
  geom_line() +
  # annotate(geom = "line", x = 4, y = 35, xend = 2.65, yend = 27, arrow = arrow(length = unit(2, "mm"))) +
  geom_segment(aes(x = 4, y = 17, 
                   xend = 3.1, 
                   yend = 6), 
               arrow = arrow(length = unit(0.2, "cm"))) +
  annotate("text", x = 4.4, y = 19, label = "Elbow", size = 7, color = "blue") +
  labs(x = "Number of Clusters", y = "Total WSSD") +
  #theme(text = element_text(size = 20)) +
  scale_x_continuous(breaks = 1:9)
p2
```

## Data pre-processing for K-means

Similar to K-nearest neighbors classification and regression, K-means 
clustering uses straight-line distance to decide which points are similar to 
each other. Therefore, the *scale* of each of the variables in the data
will influence which cluster each data point ends up being assigned to.
Variables with a large scale will have a much larger 
effect on deciding cluster assignment than variables with a small scale. 
To address this problem, we typically standardize \index{standardization!K-means}\index{K-means!standardization} our data before clustering,
which ensures that each variable has a mean of 0 and standard deviation of 1.
The `scale` function in R can be used to do this. 
We show an example of how to use this function 
below using an unstandardized version of the data set in this chapter.

```{r 10-get-unscaled-data, echo = FALSE, message = FALSE, warning = FALSE}
unstandardized_data <- read_csv("data/toy_penguins.csv") |>
  select(bill_length_mm, flipper_length_mm)

write_csv(unstandardized_data, "data/penguins_not_standardized.csv")
```

First, here is what the raw (i.e., not standardized) data looks like:

```{r}
not_standardized_data <- read_csv("data/penguins_not_standardized.csv")
not_standardized_data
```

We then apply the `scale` function to every column in the data frame 
using `mutate` and `across`.

```{r 10-mapdf-scale-data}
standardized_data <- not_standardized_data |>
  mutate(across(everything(), scale))

standardized_data
```
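
To make the standardization step concrete, the sketch below reproduces what `scale` computes by hand: subtract each column's mean and divide by its standard deviation. The data frame `toy` and the helper `standardize` are hypothetical and used only for illustration; they are not part of the penguin data.

```r
# Hypothetical toy measurements (not the penguin data)
toy <- data.frame(
  bill_length_mm = c(39.1, 46.5, 50.0, 36.7, 48.7),
  flipper_length_mm = c(181, 217, 196, 193, 210)
)

# Standardize by hand: subtract the mean, then divide by the standard deviation
standardize <- function(x) (x - mean(x)) / sd(x)
toy_by_hand <- as.data.frame(lapply(toy, standardize))

# scale() returns a matrix, so convert it to a data frame for comparison
toy_scaled <- as.data.frame(scale(toy))

# Compare the values produced by the two approaches
all.equal(as.matrix(toy_by_hand), as.matrix(toy_scaled),
          check.attributes = FALSE)

# Each standardized column should have mean (numerically) 0 and standard deviation 1
colMeans(toy_by_hand)
sapply(toy_by_hand, sd)
```

After standardization, both variables are on a comparable scale, so neither dominates the straight-line distances that K-means relies on.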

## K-means in R

To perform K-means clustering in R, we use the `kmeans` function. \index{K-means!kmeans function} It takes at
least two arguments: the data frame containing the data you wish to cluster,
and K, the number of clusters (here we choose K = 3). Note that the K-means
algorithm uses a random initialization of assignments; but since we set the random seed
earlier, the clustering will be reproducible.

```{r 10-kmeans-seed, echo = FALSE, warning = FALSE, message = FALSE}
# hidden seed
set.seed(1234)
```

```{r 10-kmeans}
penguin_clust <- kmeans(standardized_data, centers = 3)
penguin_clust
```

As you can see above, the clustering object returned by `kmeans` has a lot of information
that can be used to visualize the clusters, pick K, and evaluate the total WSSD.
To obtain this information in a tidy format, we will call in help 
from the `broom` package. \index{broom} Let's start by visualizing the clustering
as a colored scatter plot. To do that,
we use the `augment` function, \index{K-means!augment} \index{augment} which takes in the model and the original data
frame, and returns a data frame with the data and the cluster assignments for
each point:

```{r 10-plot-clusters-1}
library(broom)

clustered_data <- augment(penguin_clust, standardized_data)
clustered_data
```

Now that we have this information in a tidy data frame, we can make a visualization
of the cluster assignments for each point, as shown in Figure \@ref(fig:10-plot-clusters-2).

```{r 10-plot-clusters-2, fig.height = 3.25, fig.width = 4.25, fig.align = "center", fig.pos = "H", out.extra="", fig.cap = "The data colored by the cluster assignments returned by K-means."}
cluster_plot <- ggplot(clustered_data,
  aes(x = flipper_length_mm, 
      y = bill_length_mm, 
      color = .cluster)) +
  geom_point(size = 2) +
  labs(x = "Flipper Length (standardized)", 
       y = "Bill Length (standardized)", 
       color = "Cluster") + 
  scale_color_manual(values = c("dodgerblue3",
                                "darkorange3",  
                                "goldenrod1")) + 
  theme(text = element_text(size = 12))

cluster_plot
```

As mentioned above, we also need to select K by finding
where the "elbow" occurs in the plot of total WSSD versus the number of clusters. 
We can obtain the total WSSD (`tot.withinss`) \index{WSSD!total} from our
clustering using `broom`'s `glance` function. For example:

```{r 10-glance}
glance(penguin_clust)
```
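
As a check on what `tot.withinss` measures, the sketch below recomputes the within-cluster sums of squared distances by hand and compares them with the values stored on the object returned by `kmeans`. The matrix `toy` is simulated, hypothetical data (not the penguin data), and the names `toy_clust` and `wssd_by_hand` are introduced here only for illustration.

```r
set.seed(1)

# Hypothetical toy data: two well-separated groups in two dimensions
toy <- rbind(
  matrix(rnorm(40, mean = 0, sd = 0.3), ncol = 2),
  matrix(rnorm(40, mean = 3, sd = 0.3), ncol = 2)
)

toy_clust <- kmeans(toy, centers = 2)

# For each cluster, sum the squared straight-line distances
# from its points to its center
wssd_by_hand <- sapply(1:2, function(j) {
  pts <- toy[toy_clust$cluster == j, , drop = FALSE]
  sum(sweep(pts, 2, toy_clust$centers[j, ])^2)
})

wssd_by_hand            # per-cluster WSSD computed by hand
toy_clust$withinss      # per-cluster WSSD stored by kmeans
sum(wssd_by_hand)       # total WSSD computed by hand
toy_clust$tot.withinss  # total WSSD stored by kmeans
```

The hand-computed values should agree with the stored ones up to numerical precision, which shows that the total WSSD is simply the sum of the per-cluster WSSDs.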

To calculate the total WSSD for a range of values of K, we will
create a data frame with a column named `k` with rows containing
each value of K we want to run K-means with (here, 1 to 9). 

```{r 10-choose-k-part1}
penguin_clust_ks <- tibble(k = 1:9)
penguin_clust_ks
```

Then we use `rowwise` \index{rowwise} + `mutate` to apply the `kmeans` function 
within each row to each K. 
However, given that the `kmeans` function 
returns a model object to us (not a vector),
we will need to store the results as a list column.
This works because both vectors and lists are legitimate 
data structures for data frame columns. 
To make this work, 
we have to put each model object in a list using the `list` function.
We demonstrate how to do this below:

```{r}
penguin_clust_ks <- tibble(k = 1:9) |>
  rowwise() |>
  mutate(penguin_clusts = list(kmeans(standardized_data, k)))
```

If we take a look at our data frame `penguin_clust_ks` now, 
we see that it has two columns: one with the value for K, 
and the other holding the clustering model object in a list column.

```{r}
penguin_clust_ks
```

If we wanted to get one of the clusterings out 
of the list column in the data frame,
we could use a familiar friend: `pull`.
`pull` will return to us a data frame column as a simpler data structure;
here, that would be a list.
And then to extract the first item of the list, 
we can use the `pluck` function. We pass  
it the index for the element we would like to extract 
(here, `1`).

```{r}
penguin_clust_ks |>
  pull(penguin_clusts) |>
  pluck(1)
```

Next, we use `mutate` again to apply `glance` \index{glance} 
to each of the K-means clustering objects to get the clustering statistics 
(including the total WSSD). 
The output of `glance` is a data frame, 
and so we need to create another list column (using `list`) for this to work. 
This results in a complex data frame with three columns: one for K, one for the 
K-means clustering objects, and one for the clustering statistics:

```{r 10-choose-k-part3}
penguin_clust_ks <- tibble(k = 1:9) |>
  rowwise() |>
  mutate(penguin_clusts = list(kmeans(standardized_data, k)),
         glanced = list(glance(penguin_clusts)))

penguin_clust_ks
```

Finally, we extract the total WSSD from the column named `glanced`. 
Given that each item in this list column is a data frame, 
we will need to use the `unnest` function 
to unpack the data frames into simpler column data types. 

```{r 10-get-total-within-sumsqs}
clustering_statistics <- penguin_clust_ks |>
  unnest(glanced)

clustering_statistics
```

Now that we have `tot.withinss` and `k` as columns in a data frame, we can make a line plot 
(Figure \@ref(fig:10-plot-choose-k)) and search for the "elbow" to find which value of K to use. 

```{r 10-plot-choose-k, fig.height = 3.25, fig.width = 4.25, fig.align = "center", fig.pos = "H", out.extra="", fig.cap = "A plot showing the total WSSD versus the number of clusters."}
elbow_plot <- ggplot(clustering_statistics, aes(x = k, y = tot.withinss)) +
  geom_point() +
  geom_line() +
  xlab("K") +
  ylab("Total within-cluster sum of squares") +
  scale_x_continuous(breaks = 1:9) + 
  theme(text = element_text(size = 12))

elbow_plot
```

It looks like 3 clusters is the right choice for this data.
But why is there a "bump" in the total WSSD plot here? 
Shouldn't total WSSD always decrease as we add more clusters? 
Technically yes, but remember:  K-means can get "stuck" in a bad solution. 
Unfortunately, for K = 8 we had an unlucky initialization
and found a bad clustering! \index{K-means!restart, nstart} 
We can help prevent finding a bad clustering 
by trying a few different random initializations 
via the `nstart` argument (Figure \@ref(fig:10-choose-k-nstart) 
shows a setup where we use 10 restarts). 
When we do this, K-means clustering will be performed 
the number of times specified by the `nstart` argument,
and R will return to us the clustering with the lowest total WSSD among them.
The more times we perform K-means clustering,
the more likely we are to find a good clustering (if one exists).
What value should you choose for `nstart`? The answer is that it depends
on many factors: the size and characteristics of your data set,
as well as the speed and size of your computer.
A larger `nstart` value is better from an analysis perspective, 
but running many clusterings could take a long time, 
so the two need to be balanced.

```{r 10-choose-k-nstart, fig.height = 3.25, fig.width = 4.25, fig.pos = "H", out.extra="", message= FALSE, warning = FALSE, fig.align = "center", fig.cap = "A plot showing the total WSSD versus the number of clusters when K-means is run with 10 restarts."}
penguin_clust_ks <- tibble(k = 1:9) |>
  rowwise() |>
  mutate(penguin_clusts = list(kmeans(standardized_data, centers = k, nstart = 10)),
         glanced = list(glance(penguin_clusts)))

clustering_statistics <- penguin_clust_ks |>
  unnest(glanced)

elbow_plot <- ggplot(clustering_statistics, aes(x = k, y = tot.withinss)) +
  geom_point() +
  geom_line() +
  xlab("K") +
  ylab("Total within-cluster sum of squares") +
  scale_x_continuous(breaks = 1:9) + 
  theme(text = element_text(size = 12))

elbow_plot
```
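
To illustrate the idea behind restarts, the sketch below mimics them by hand: it runs `kmeans` several times from different random initializations and keeps the run with the lowest total WSSD. This is only a sketch of the idea, not a description of how `nstart` is implemented internally. The matrix `toy` is simulated, hypothetical data, and the names `runs`, `total_wssd`, and `best_run` are introduced here only for illustration.

```r
set.seed(2024)

# Hypothetical toy data: three groups in two dimensions
toy <- rbind(
  matrix(rnorm(60, mean = 0, sd = 0.5), ncol = 2),
  matrix(rnorm(60, mean = 3, sd = 0.5), ncol = 2),
  cbind(rnorm(30, mean = 0, sd = 0.5), rnorm(30, mean = 3, sd = 0.5))
)

# Run K-means 10 times; each call starts from a new random initialization
runs <- lapply(1:10, function(i) kmeans(toy, centers = 3))

# Total WSSD of each run
total_wssd <- sapply(runs, function(run) run$tot.withinss)
total_wssd

# Keep the run with the lowest total WSSD
best_run <- runs[[which.min(total_wssd)]]
best_run$tot.withinss
```

If some entries of `total_wssd` are noticeably larger than the others, those runs got "stuck" in a worse solution; selecting the minimum discards them.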

## Exercises

Practice exercises for the material covered in this chapter 
can be found in the accompanying 
[worksheets repository](https://github.com/UBC-DSCI/data-science-a-first-intro-worksheets#readme)
in the "Clustering" row.
You can launch an interactive version of the worksheet in your browser by clicking the "launch binder" button.
You can also preview a non-interactive version of the worksheet by clicking "view worksheet."
If you instead decide to download the worksheet and run it on your own machine,
make sure to follow the instructions for computer setup
found in Chapter \@ref(move-to-your-own-machine). This will ensure that the automated feedback
and guidance that the worksheets provide will function as intended.

## Additional resources
- Chapter 10 of *An Introduction to Statistical
  Learning* [@james2013introduction] provides a
  great next stop in the process of learning about clustering and unsupervised
  learning in general. In the realm of clustering specifically, it provides a
  great companion introduction to K-means, but also covers *hierarchical*
  clustering for when you expect there to be subgroups, and then subgroups within
  subgroups, etc., in your data. In the realm of more general unsupervised
  learning, it covers *principal components analysis (PCA)*, which is a very
  popular technique for reducing the number of predictors in a data set.
