Hloupě horlivá heuristika

Lukáš Veškrna

K-means clustering – Lloyd’s algorithm

Published Lukáš Veškrna

In this article, we will take a look at k-means clustering. We will write an implementation of Lloyd’s algorithm, a naive heuristic approach to k-means clustering. Better algorithms are known, but Lloyd’s algorithm is simple and very common. I would also like to cover some applications of k-means and some extensions that improve the algorithm.

First of all, let’s talk about what k-means clustering is and what problems it solves.


K-means has applications in data analysis, image processing and many other fields. For example, it can be used to build a recommendation system by separating customers (or products) into groups (clusters) and then recommending products based on the cluster the customer (or the product) belongs to; in practice, multiple clusterings over different features are built. It can also be used to gain insight into data.

Possible uses in image processing include image segmentation, color reduction or dithering. We will get to color reduction later.

It can also be used for unsupervised classification, for instance in problems where we know the number of classes, but it is difficult to get a large number of labels for our data.

So what is K-means clustering?

K-means clustering is an unsupervised learning technique used to partition $d$-dimensional data points into $k$ clusters (yes, the same $k$ as in K-means). (Note: $k$ is a parameter that must be chosen to suit the problem. There are methods for choosing a good $k$ when it is unclear, but they are beyond the scope of this article. [2]) Data points are assigned to clusters based on their distances from $k$ centroids (or means). The goal is to minimize the within-cluster variance (also called inertia or within-cluster sum of squares), [1] which is defined as

$$\sum_{i=1}^{n} \min_{c \in C} \lVert c - x_i \rVert^2$$

where $n$ is the number of data points, $x_i$ is the $i$-th data point and $C$ is the set of centroids. [1] Finding the solution with minimal inertia is not easy, however. The problem is NP-hard, [8] so Lloyd’s algorithm only approximates the best possible locations of the centroids.
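As a small illustration of this objective, the following sketch evaluates the inertia of two hand-picked sets of centroids on four 1-D points. The data and the helper name `inertia` are made up for this example.

```python
import numpy as np

def inertia(centroids, points):
    """Sum of squared distances from each point to its nearest centroid."""
    # Pairwise squared distances, shape (n_points, n_centroids)
    sq_dist = ((points[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2).sum(axis=2)
    return sq_dist.min(axis=1).sum()

points = np.array([[0.0], [2.0], [10.0], [12.0]])
good = np.array([[1.0], [11.0]])  # one centroid per visible group
bad = np.array([[0.0], [2.0]])    # both centroids in the left group

print(inertia(good, points))  # 4.0
print(inertia(bad, points))   # 164.0
```

The lower value for `good` reflects that its centroids sit in the middle of the two groups, which is exactly what the objective rewards.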

Lloyd’s algorithm

The algorithm works iteratively. First, it creates the $k$ centroids with some initialization method. (Note: the initial centroids could be placed at random positions, but the algorithm would not perform very well. Instead, the initial centroids are usually chosen from the data points, or more advanced methods such as k-means++ are used. We will revisit k-means++ in the Improvements section.) Then it assigns every data point to the closest centroid, which splits the data into $k$ clusters. Next, each centroid is moved to the mean of its cluster, the points are reassigned, and the whole process is repeated. [1][2]

Figure 1: The assignment of data points to their closest centroids can be represented by a Voronoi diagram generated by the centroids. A Voronoi diagram (also called a Voronoi tessellation) divides the plane into regions, each containing the points that are closer to one particular generator point than to any other. The diagram is fully determined by the set of generator points. In our case, however, we won’t construct a Voronoi diagram, since a simple nearest-centroid lookup will do. [6]

After some number of iterations the algorithm converges. However, it can get stuck in a local optimum, which is why better implementations run it multiple times with different initial conditions and pick the best result. [1] We will implement this in the Improvements section.

To summarize:

  1. Choose $k$ initial centroids
  2. Assign every sample to the closest centroid
  3. Move the centroids to the means of their corresponding clusters
  4. Repeat steps 2-3 until the algorithm converges or reaches a maximum number of steps.
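To make the four steps concrete before the full implementation, here is a tiny trace on four 1-D points. It is only a sketch with made-up data and a deliberately poor starting position, and it ignores edge cases such as empty clusters.

```python
import numpy as np

points = np.array([1.0, 2.0, 9.0, 10.0])
centroids = np.array([1.0, 2.0])  # step 1: a deliberately poor initialization

history = [centroids.copy()]
for _ in range(10):  # step 4: cap the number of iterations
    # Step 2: index of the closest centroid for every point
    assignment = np.abs(points[:, np.newaxis] - centroids).argmin(axis=1)
    # Step 3: move each centroid to the mean of its cluster
    updated = np.array([points[assignment == j].mean() for j in range(len(centroids))])
    if np.allclose(updated, centroids):
        break  # step 4: the centroids stopped moving
    centroids = updated
    history.append(centroids.copy())

print(history)
```

The centroids move from (1, 2) to (1, 7) and then to (1.5, 9.5), where they stay.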


Our implementation will be simple and slow, with the sole purpose of understanding the algorithm by implementing it.

Sample data

First, we create three clusters of random sample data in two dimensions.

%matplotlib inline
%config InlineBackend.figure_format = "svg"

import numpy as np
import matplotlib.pyplot as plt
from numpy.random import default_rng

plt.rc("axes", facecolor=(1, 1, 1, 0)) # Transparent plot background for dark mode viewers

rng = default_rng(0)

datapoints = np.zeros((150, 2)) # 150 datapoints with 2 dimensions
labels = np.zeros(150,)

# First cluster
datapoints[:50, 0] = rng.normal(6, 0.5, (50,))
datapoints[:50, 1] = rng.normal(3, 0.5, (50,))
labels[:50] = 0
# Second cluster
datapoints[50:100, 0] = rng.normal(5, 0.5, (50,))
datapoints[50:100, 1] = rng.normal(0, 0.5, (50,))
labels[50:100] = 1
# Third cluster
datapoints[100:, 0] = rng.normal(2, 0.5, (50,))
datapoints[100:, 1] = rng.normal(1, 0.5, (50,))
labels[100:] = 2

# Plot it
plt.scatter(datapoints[:, 0], datapoints[:, 1], c=labels)



Let’s create an initialization function that chooses the initial centroids randomly from the data points. It never picks the same data point twice, because two identical centroids would cause a cluster to degenerate and the data would be split into fewer than $k$ clusters. (This can still happen if the dataset contains duplicates.) [3]

def init_random(k, dataset, rng=default_rng()):
    """Choose `k` centroids from the supplied `dataset` randomly"""
    return dataset[rng.choice(dataset.shape[0], size=k, replace=False)]
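The degenerate case mentioned above is easy to reproduce. In this sketch (made-up data, and a hypothetical bad initialization that picked the same point twice), every point is assigned to the first of the two identical centroids, so the second cluster stays empty.

```python
import numpy as np

points = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [6.0, 5.0]])
# Hypothetical bad initialization: the same data point picked twice
centroids = np.array([[0.0, 0.0], [0.0, 0.0]])

sq_dist = ((points[:, np.newaxis, :] - centroids) ** 2).sum(axis=2)
assignment = sq_dist.argmin(axis=1)  # ties resolve to the lowest index

print(assignment)                            # [0 0 0 0]
print(np.bincount(assignment, minlength=2))  # [4 0]
```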

The predict function

Next, we write a function that assigns points from a dataset to their clusters given the centroids. We will need a function cluster_dist that computes the matrix of squared distances from every sample to every centroid. (Note: because we are using numpy, computing a distance matrix is faster than using a Python loop to find the closest centroid. numpy is written in C and takes advantage of many optimizations, such as SIMD.)

def cluster_dist(centroids, dataset):
    """
    Compute the squared distance from every cluster
    for every sample in `dataset`.
    Return a distance matrix with shape (n, k),
    where n is the number of samples in `dataset`.
    """
    # Broadcasting (n, 1, d) against (k, d) gives all differences, shape (n, k, d)
    diff = dataset[:, np.newaxis, :] - centroids
    return (diff**2).sum(axis=2)

def predict(centroids, dataset):
    """
    Find the corresponding cluster indices
    given the `centroids` for every datapoint from `dataset`.
    """
    return cluster_dist(centroids, dataset).argmin(axis=1)
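A quick way to gain confidence in a vectorized distance matrix is to compare it with explicit Python loops on a small random input. The sketch below is self-contained and uses its own variable names rather than the functions above.

```python
import numpy as np

rng = np.random.default_rng(0)
dataset = rng.normal(size=(5, 2))    # 5 samples, 2 dimensions
centroids = rng.normal(size=(3, 2))  # 3 centroids

# Vectorized: (5, 1, 2) - (3, 2) broadcasts to (5, 3, 2)
vectorized = ((dataset[:, np.newaxis, :] - centroids) ** 2).sum(axis=2)

# Reference: explicit Python loops over samples and centroids
looped = np.empty((5, 3))
for i, x in enumerate(dataset):
    for j, c in enumerate(centroids):
        looped[i, j] = ((x - c) ** 2).sum()

print(vectorized.shape)                 # (5, 3)
print(np.allclose(vectorized, looped))  # True
```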


Now we create kmeans_step which performs one iteration of the algorithm.

def kmeans_step(centroids, dataset):
    """Perform a single iteration of the algorithm. Return updated centroids."""
    classes = predict(centroids, dataset)
    return np.array([dataset[classes==cls].mean(axis=0) for cls in range(len(centroids))])
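One caveat: if a cluster ends up with no points, taking the mean of an empty selection yields NaN (and NumPy emits a warning). A possible variation, which is not part of the implementation in this article, keeps the previous centroid in that case. The name `kmeans_step_safe` is hypothetical.

```python
import numpy as np

def kmeans_step_safe(centroids, dataset):
    """One Lloyd iteration that leaves the centroid of an empty cluster unchanged."""
    sq_dist = ((dataset[:, np.newaxis, :] - centroids) ** 2).sum(axis=2)
    classes = sq_dist.argmin(axis=1)
    updated = centroids.astype(float)  # astype returns a copy
    for cls in range(len(centroids)):
        members = dataset[classes == cls]
        if len(members) > 0:
            updated[cls] = members.mean(axis=0)
    return updated

dataset = np.array([[0.0, 0.0], [2.0, 0.0]])
centroids = np.array([[1.0, 0.0], [50.0, 50.0]])  # second centroid attracts no points
result = kmeans_step_safe(centroids, dataset)
print(result)
```

Other strategies exist, such as re-seeding the empty cluster from a data point; keeping the old centroid is simply the smallest change.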

Training function

It’s time to write the generator kmeans_base. It yields the current centroids and then updates them, once for every step of training on dataset.

from itertools import islice

def kmeans_base(k, dataset, init=init_random, rng=default_rng()):
    """
    A generator that fits `k`-means on `dataset`.
    Centroids are initialized with `init`.
    """
    centroids = init(k, dataset, rng=rng)
    while True:
        yield centroids
        centroids = kmeans_step(centroids, dataset)

until_convergence and until_max_iter take values from the generator while certain conditions are satisfied. [4] We use these functions to create a combined function kmeans_iter.

def until_convergence(it, tol=1e-4):
    """
    Yield values from `it` until they converge (i.e. stop changing)
    within tolerance `tol`.
    """
    prev = None
    cur = next(it)
    while prev is None or np.linalg.norm(cur-prev) > tol:
        yield cur
        prev = cur
        cur = next(it)

def until_max_iter(it, max_iter=300):
    """
    Yield at most `max_iter` values from `it`.
    `max_iter` can be None, meaning no limit.
    """
    return islice(it, 0, max_iter)

def kmeans_iter(k, dataset, tol=1e-4, max_iter=300, init=init_random, rng=default_rng()):
    """
    Fits `k`-means on `dataset` until it converges
    within tolerance `tol` or reaches `max_iter`.
    Centroids are initialized with `init`.
    Yields centroids every iteration.
    """
    yield from until_max_iter(
        until_convergence(kmeans_base(k, dataset, init=init, rng=rng), tol=tol),
        max_iter=max_iter)
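The pattern of wrapping an infinite generator in separate stopping conditions is not specific to k-means. As a minimal, self-contained sketch, here it is applied to Newton’s iteration for a square root. The names `newton_sqrt` and `until_close` are made up for this illustration.

```python
from itertools import islice

def newton_sqrt(a):
    """Infinite generator of Newton iterates converging to sqrt(a)."""
    x = a
    while True:
        yield x
        x = 0.5 * (x + a / x)

def until_close(it, tol=1e-9):
    """Yield values from `it` until two successive values differ by at most `tol`."""
    prev = None
    for cur in it:
        if prev is not None and abs(cur - prev) <= tol:
            return
        yield cur
        prev = cur

# Stop on convergence or after 50 values, whichever comes first
values = list(islice(until_close(newton_sqrt(2.0)), 50))
print(len(values), values[-1])
```

Because each concern lives in its own small function, the convergence test and the iteration cap can be combined or replaced independently.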

Finally, the main kmeans function, which fits k-means on a given dataset and returns the final centroid locations.

def kmeans(k, dataset, tol=1e-4, max_iter=300, init=init_random, rng=default_rng()):
    """
    Fits `k`-means on `dataset` until it converges
    within tolerance `tol` or reaches `max_iter`.
    Centroids are initialized with `init`.
    Return the trained centroids.
    """
    # Exhaust the iterator; the last yielded value is the final set of centroids
    for x in kmeans_iter(k, dataset, tol=tol, max_iter=max_iter, init=init, rng=rng):
        pass
    return x

Running on test data

model = kmeans(3, datapoints, rng=rng)
labels_pred = predict(model, datapoints)

plt.scatter(datapoints[:, 0], datapoints[:, 1], c=labels_pred)



Now let’s create an animation of the training process. This time we will just use uniformly distributed random points. (Note: one interesting observation about running k-means on random points is that you can see the underlying Voronoi diagram emerge.)

from matplotlib.animation import FuncAnimation
from IPython.display import HTML

rng = default_rng(0)
datapoints = rng.uniform(0, 5, (400, 2))
steps = list(kmeans_iter(6, datapoints, rng=rng))

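def frame(n):
    plt.cla()  # Clear the previous frame so old centroids do not linger
    model = steps[n]
    plt.scatter(datapoints[:, 0], datapoints[:, 1], c=predict(model, datapoints))
    plt.scatter(model[:, 0], model[:, 1], marker="x", c="black")

anim = FuncAnimation(plt.figure(), frame, frames=len(steps), interval=100)
HTML(anim.to_jshtml())  # Render the animation inline in the notebook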


Applying it

Now let’s try applying our k-means to some problems. Our implementation is clumsy and slow, and these problems would be solved better with a different one, but in this article we will use ours to see whether it is useful enough. In the real world, you should probably use an established implementation such as sklearn.cluster.KMeans instead.
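For comparison, this is roughly what the same task looks like with scikit-learn. The sketch below uses synthetic blobs similar in spirit to the sample data above; the blob positions and the chosen parameter values are assumptions made for the example.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Three synthetic 2-D blobs with 50 points each
blob_centers = np.array([[6.0, 3.0], [5.0, 0.0], [2.0, 1.0]])
X = np.concatenate([rng.normal(c, 0.5, (50, 2)) for c in blob_centers])

# n_init restarts the algorithm several times and keeps the best run
km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

print(km.cluster_centers_)  # learned centroids, shape (3, 2)
print(km.labels_[:5])       # cluster index of the first five samples
print(km.inertia_)          # within-cluster sum of squares
```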

Color reduction

The first problem we will try to solve is color reduction. This can be done for viewing on some medium that does not support enough colors, for compression or for artistic purposes.

The first step is to load an image. The photo is taken from [5].

import imageio

img = imageio.imread("photo.jpg")
img.shape
(4288, 2848, 3)

Original image

Every pixel can be represented as a 3-dimensional vector. The goal is to cluster these vectors to find the optimal colors to use, which will be the centroids. Thus the number of clusters, $k$, is equal to the number of colors in the reduced image.

We could cluster all the pixels in the image, but our implementation is too slow for that. Instead, we sample a random subset of the pixels and cluster those. Then we determine the cluster of every pixel in the image and use its centroid as the pixel’s color.

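NUM_COLORS = 8  # The reduced image below uses eight colors
SAMPLE_PIXELS = 10_000  # Assumed sample size; the value used for the results in this article is not shown

rng = default_rng(0)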

h, w, c = img.shape
assert c == 3

normalized = img / 255
pixels = np.reshape(normalized, (h*w, c))
sampled_pixels = pixels[rng.integers(0, h*w, (SAMPLE_PIXELS,))]

model = kmeans(NUM_COLORS, sampled_pixels, rng=rng)

reduced_img = (model[predict(model, pixels)].reshape((h, w, c)) * 255).astype(np.uint8)
imageio.imwrite("reduced.jpg", reduced_img)

Reduced color image

And that’s our image using only eight colors. It’s still pretty recognizable.
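The last step, replacing every pixel by the color of its nearest centroid, relies on NumPy integer-array indexing. The following sketch shows it on a tiny 2×2 image with a hypothetical two-color palette standing in for trained centroids.

```python
import numpy as np

# A tiny 2x2 RGB "image" with values in [0, 1]
img = np.array([[[0.9, 0.1, 0.1], [0.8, 0.2, 0.0]],
                [[0.1, 0.1, 0.9], [0.0, 0.2, 1.0]]])
h, w, c = img.shape
pixels = img.reshape(h * w, c)

# Hypothetical 2-color palette standing in for trained centroids
palette = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# Index of the nearest palette color for every pixel
nearest = ((pixels[:, np.newaxis, :] - palette) ** 2).sum(axis=2).argmin(axis=1)
# palette[nearest] looks up one color per pixel, shape (h*w, 3)
reduced = (palette[nearest].reshape(h, w, c) * 255).astype(np.uint8)

print(nearest)        # [0 0 1 1]
print(reduced[0, 0])  # [255   0   0]
```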

Andy Warhol inspired pictures

Figure 2: One of Andy Warhol’s artworks
For one of my school art assignments I used this color reduction technique to create pictures inspired by Andy Warhol’s artworks. He made pictures of celebrities and often included a few different variations in one picture.

I picked two famous people of the present day, Taylor Swift and Barack Obama, used k-means to reduce the colors and changed the hue in a few different ways.
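The article does not show how the hue was changed. One possible way, sketched here under that assumption, is to rotate the hue of each palette color in HSV space with the standard-library `colorsys` module and then map pixels to the shifted palette. The function name `shift_hue` is hypothetical.

```python
import colorsys
import numpy as np

def shift_hue(palette, shift):
    """Rotate the hue of every RGB color in `palette` (values in [0, 1]) by `shift`."""
    shifted = []
    for r, g, b in palette:
        h, s, v = colorsys.rgb_to_hsv(r, g, b)
        # Hue is cyclic, so wrap around at 1.0
        shifted.append(colorsys.hsv_to_rgb((h + shift) % 1.0, s, v))
    return np.array(shifted)

palette = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # red, blue
print(shift_hue(palette, 1 / 3))  # red becomes green, blue becomes red
```

Since only the few palette colors need to be converted, this is cheap even for large images.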

Classifying handwritten digits

This time we will try applying k-means to unsupervised recognition of handwritten digits, using the Handwritten Digits Data Set from the UCI Machine Learning Repository. [7] We will use a scikit-learn convenience function to load the dataset.

There are 1797 digits and every digit is $8 \times 8 = 64$ pixels. Each pixel takes an integer value from 0 to 16.

from sklearn.datasets import load_digits

digits = load_digits()

print("Shape:", digits.data.shape)
print("Format:", digits.data[0])

fig, axs = plt.subplots(3, 3)
for i, ax in enumerate(axs.reshape(-1)):
    ax.imshow(digits.data[i].reshape(8, 8), cmap="gray_r")
Shape: (1797, 64)
Format: [ 0.  0.  5. 13.  9.  1.  0.  0.  0.  0. 13. 15. 10. 15.  5.  0.  0.  3.
 15.  2.  0. 11.  8.  0.  0.  4. 12.  0.  0.  8.  8.  0.  0.  5.  8.  0.
  0.  9.  8.  0.  0.  4. 11.  0.  1. 12.  7.  0.  0.  2. 14.  5. 10. 12.
  0.  0.  0.  0.  6. 13. 10.  0.  0.  0.]


The next step is to cluster the digits with k-means for $k = 10$ (ten types of digits).


CLASSES = 10

# Run until convergence; `model` ends up holding the final centroids
for i, model in enumerate(kmeans_iter(CLASSES, digits.data, rng=default_rng(0))):
    pass

print(f"Converged after {i} steps.")
Converged after 14 steps.

Because our classification was unsupervised, we need to find out which cluster corresponds to which class. We determine the cluster of every digit and then, for each cluster, take the most common true digit among its members using mode.

from scipy.stats import mode

def assign_labels(clusters, target, num_classes):
    cluster_labels = np.zeros((num_classes,)) # A mapping (cluster_id -> label)

    for cluster_idx in range(num_classes):
        real_labels = target[clusters == cluster_idx]

        # The most common real label in the cluster
        cluster_labels[cluster_idx] = mode(real_labels).mode 

    return cluster_labels

clusters = predict(model, digits.data)
cluster_labels = assign_labels(clusters, digits.target, CLASSES)

predicted = cluster_labels[clusters] # Map the clusters to corresponding labels
correct = (predicted == digits.target).sum()
total = len(predicted)
accuracy = correct/total
print(f"Correctly classified {correct} out of {total}.\nAccuracy: {accuracy}")
Correctly classified 1424 out of 1797.
Accuracy: 0.7924318308291597
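The mapping from clusters to labels can be checked by hand on a tiny example. This sketch uses made-up arrays and `np.bincount` instead of `scipy.stats.mode`, purely to keep it dependency-free.

```python
import numpy as np

clusters = np.array([0, 0, 0, 1, 1, 2, 2, 2])  # cluster index per sample
target = np.array([7, 7, 1, 3, 3, 1, 1, 7])    # true label per sample

# Majority label per cluster; np.bincount counts occurrences of each label
cluster_labels = np.array([
    np.bincount(target[clusters == c]).argmax() for c in range(3)
])
predicted = cluster_labels[clusters]  # map every sample's cluster to its label

print(cluster_labels)                # [7 3 1]
print((predicted == target).mean())  # 0.75
```

Note that nothing prevents two clusters from receiving the same label, in which case some digit would never be predicted.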

An accuracy of 79% is pretty good if you take into account that the classifier is completely unsupervised and achieved this score only by grouping the digits into 10 clusters.

To get more insight, we can generate a confusion matrix. It tells us, for every pair of labels, how many samples with the first label were classified as the second.

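def confusion_matrix(predicted, target, num_classes):
    """Plot a confusion matrix using matplotlib."""
    mat = np.zeros((num_classes, num_classes), dtype=int)

    ax = plt.axes()

    for i in range(num_classes):
        for j in range(num_classes):
            # Number of samples with true label i that were predicted as j
            mat[i, j] = ((target == i) & (predicted == j)).sum()
            ax.text(j, i, mat[i, j], ha="center", va="center", color="w")

    # Draw the matrix itself; rows are true labels, columns are predicted labels
    ax.imshow(mat)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")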



confusion_matrix(predicted, digits.target, CLASSES)


As you can see, the k-means classifier has some trouble distinguishing 1 and 8.
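The counting behind a confusion matrix can also be done without plotting. As a small sketch with made-up labels, `np.add.at` increments one cell per sample:

```python
import numpy as np

target = np.array([0, 0, 1, 1, 1, 2])
predicted = np.array([0, 1, 1, 1, 2, 2])

mat = np.zeros((3, 3), dtype=int)
# Rows are true labels, columns are predicted labels
np.add.at(mat, (target, predicted), 1)

print(mat)
# [[1 1 0]
#  [0 2 1]
#  [0 0 1]]
print(np.trace(mat) / mat.sum())  # accuracy: 4 of 6 samples lie on the diagonal
```

Off-diagonal cells show which pairs of labels get confused with each other.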


There are many improvements, extensions and modifications of the k-means algorithm. Let’s take a look at the most common ones.

Multiple runs

The algorithm can get stuck in a local optimum. To improve this, we can run it multiple times with different initial conditions and pick the best result (with minimal within-cluster variance).

def kmeans_multiple(k, dataset, n_init=10, tol=1e-4, max_iter=300, 
                    init=init_random, rng=default_rng()):
    """
    Run `k`-means on `dataset` `n_init` times
    until convergence within `tol`
    or until it reaches `max_iter` iterations.
    Then we pick the result with minimal within-cluster variance.
    """
    # A list (not a generator), so every result is scored exactly once
    results = [kmeans(k, dataset,
                      tol=tol, max_iter=max_iter,
                      init=init, rng=rng) for _ in range(n_init)]

    def inertia(centroids):
        """Within-cluster sum of squares for the given centroids."""
        return cluster_dist(centroids, dataset).min(axis=1).sum()

    return min(results, key=inertia)

Now let’s train the modified algorithm on the handwritten digit recognition problem and evaluate it.

model = kmeans_multiple(CLASSES, digits.data, rng=default_rng(0))

def evaluate_digits(model):
    clusters = predict(model, digits.data)
    cluster_labels = assign_labels(clusters, digits.target, CLASSES)
    predicted = cluster_labels[clusters]

    correct = (predicted == digits.target).sum()
    total = len(predicted)
    accuracy = correct/total
    print(f"Correctly classified {correct} out of {total}.\nAccuracy: {accuracy}")

    confusion_matrix(predicted, digits.target, CLASSES)

evaluate_digits(model)

Correctly classified 1563 out of 1797.
Accuracy: 0.8697829716193656


That’s an improvement of nearly 8 percentage points, achieved only by running the algorithm multiple times with different initial centroids. The model is now better at separating 1s and 8s. 87% is a pretty good score.


k-means++ is a different initialization method. It chooses the first centroid uniformly at random from the data points and then selects the rest of the centroids one by one. Each of them is also picked randomly from the data points, but with probability proportional to the squared distance from the nearest already chosen centroid. [8]

def init_kmeanspp(k, dataset, rng=default_rng()):
    centroids = np.empty((k, dataset.shape[1]))
    centroids[0] = init_random(1, dataset, rng=rng)[0]

    for i in range(1, k):
        w = cluster_dist(centroids[:i], dataset).min(axis=1)
        w /= w.sum()
        centroids[i] = dataset[rng.choice(np.arange(dataset.shape[0]), p=w)]

    return centroids
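To see what the weighting does, consider three 1-D points and suppose the first centroid happened to land on the point at 0. This sketch (made-up data) computes the probabilities with which each point would be picked as the second centroid.

```python
import numpy as np

points = np.array([[0.0], [1.0], [10.0]])
first_centroid = points[0]  # suppose the first centroid landed on the point at 0

# Squared distance of every point to the only centroid chosen so far
sq_dist = ((points - first_centroid) ** 2).sum(axis=1)  # [0, 1, 100]
prob = sq_dist / sq_dist.sum()

print(prob)  # [0, 1/101, 100/101]
```

The already chosen point has probability zero, and the far-away point is about a hundred times more likely to be selected than the nearby one, which spreads the initial centroids out.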

In some cases k-means++ significantly improves the final error. The k-means++ initialization itself is slower, but in the end it often reduces computation time, because the algorithm converges in fewer iterations. [8] To test this, we will compare init_random and init_kmeanspp on our color reduction problem, since it is the most time-consuming problem covered.

# Run until convergence and count the iterations
for i, model in enumerate(kmeans_iter(NUM_COLORS, sampled_pixels, rng=default_rng(0))):
    pass

print(f"init_random: Converged after {i} steps.")

for i, model in enumerate(kmeans_iter(NUM_COLORS, sampled_pixels, init=init_kmeanspp, rng=default_rng(0))):
    pass

print(f"kmeans++: Converged after {i} steps.")
init_random: Converged after 135 steps.
kmeans++: Converged after 89 steps.


I’ve been reading about k-means for some time and thought that I could understand it better by writing about it. I hope this article could be a useful introduction to k-means for somebody like me.

  1. scikit-learn User Guide, chapter 2.3. Clustering, https://scikit-learn.org/stable/modules/clustering.html#k-means

  2. Jake VanderPlas, Python Data Science Handbook, https://jakevdp.github.io/PythonDataScienceHandbook/05.11-k-means.html

  3. https://stackoverflow.com/questions/24919346/mini-batch-k-means-returns-less-than-k-clusters 

  4. Inspired by Joel Grus’ awesome talk Learning Data Science Using Functional Python 

  5. Photo by S L on Unsplash 

  6. Source: https://en.wikipedia.org/wiki/Voronoi_diagram#/media/File:Euclidean_Voronoi_diagram.svg 

  7. Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. 

  8. Wikipedia, k-means++, https://en.wikipedia.org/wiki/K-means%2B%2B