While K-Means performs hard clustering (a data point belongs to exactly one cluster), a Gaussian Mixture Model (GMM) performs soft clustering. It provides a more nuanced view by calculating the probability of each point belonging to each cluster.
The Core Idea
A GMM assumes that the data points are generated from a "mixture" of several Gaussian (normal) distributions, each with its own mean and covariance. The goal of the algorithm is to find the parameters of these underlying Gaussian distributions that best fit the data.
Because a GMM models each cluster as a full Gaussian distribution, it is not limited to the roughly circular clusters that K-Means finds. Each component's covariance matrix lets the model capture elliptical, elongated, and rotated cluster shapes.
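Formally, the model says the data density is a weighted sum of Gaussian densities, p(x) = Σ_k π_k · N(x | μ_k, Σ_k), where the mixture weights π_k sum to 1. The snippet below is a minimal sketch of evaluating such a density for a two-component 2D mixture; the particular weights, means, and covariances are made-up values chosen purely for illustration.
# Python code: evaluating a two-component mixture density (illustrative values)
import numpy as np
from scipy.stats import multivariate_normal
weights = [0.6, 0.4]                                # mixture weights, sum to 1
means = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]
covs = [np.array([[1.0, 0.8], [0.8, 1.0]]),         # tilted, elongated component
        np.array([[0.5, 0.0], [0.0, 0.5]])]         # small circular component
def mixture_density(x):
    # p(x) = sum_k pi_k * N(x | mu_k, Sigma_k)
    return sum(w * multivariate_normal(m, c).pdf(x)
               for w, m, c in zip(weights, means, covs))
print(mixture_density([1.5, 1.5]))  # density at a point between the two components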
Expectation-Maximization (EM) Algorithm
GMMs are typically fitted using an efficient algorithm called Expectation-Maximization (EM). It's an iterative process with two repeating steps:
- E-Step (Expectation): For each data point, calculate the probability that it was generated by each component Gaussian, given the current parameter estimates. This is the "soft assignment" step, and these probabilities are often called responsibilities. A point midway between two overlapping clusters might get a probability of 0.5 for each.
- M-Step (Maximization): Using the probabilities calculated in the E-step as weights, update the parameters (mean, covariance, and weight) of each Gaussian distribution to maximize the likelihood of the data. In simple terms, the center of a Gaussian will shift towards the points that have a high probability of belonging to it.
These two steps are repeated until the parameters of the distributions stabilize; a minimal sketch of a single iteration follows below.
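As a concrete illustration, here is a from-scratch sketch of one EM iteration for a K-component GMM. The function name em_step and its interface are invented for this example; in practice scikit-learn's GaussianMixture runs this loop for you with better numerical safeguards.
# Python code: one EM iteration written out explicitly (illustrative sketch)
import numpy as np
from scipy.stats import multivariate_normal
def em_step(X, weights, means, covs):
    K = len(weights)
    # E-step: responsibilities, i.e. P(component k | point x) under the current parameters
    resp = np.column_stack([
        weights[k] * multivariate_normal(means[k], covs[k]).pdf(X)
        for k in range(K)
    ])
    resp /= resp.sum(axis=1, keepdims=True)
    # M-step: re-estimate the parameters using the responsibilities as weights
    Nk = resp.sum(axis=0)                    # effective number of points per component
    new_weights = Nk / len(X)
    new_means = (resp.T @ X) / Nk[:, None]   # responsibility-weighted means
    new_covs = []
    for k in range(K):
        diff = X - new_means[k]
        new_covs.append((resp[:, k, None] * diff).T @ diff / Nk[k])  # weighted covariances
    return new_weights, new_means, new_covs
Repeating em_step until the returned parameters stop changing gives the fitted mixture.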
GMM vs. K-Means
- Assignments: GMM is soft (probabilistic), K-Means is hard (deterministic); a short comparison sketch follows this list.
- Cluster Shape: GMM can handle elliptical clusters thanks to its per-component covariance matrices. K-Means assumes spherical clusters.
- Complexity: GMM is more computationally expensive than K-Means.
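To make the first point concrete, the sketch below fits both models to the same toy blob data and compares their output for one point: K-Means returns a single hard label, while the GMM returns a probability for each component. The dataset and the chosen point are arbitrary and only for illustration.
# Python code: hard vs. soft assignments on the same data (illustrative sketch)
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture
X, _ = make_blobs(n_samples=300, centers=2, cluster_std=2.0, random_state=0)
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
gmm = GaussianMixture(n_components=2, random_state=0).fit(X)
point = X[:1]  # a single sample, kept 2-D for the predict APIs
print("K-Means label:    ", kmeans.predict(point))                 # one hard label, e.g. [1]
print("GMM probabilities:", gmm.predict_proba(point).round(3))     # e.g. [[0.85 0.15]]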
Because of its probabilistic nature, a GMM can be a powerful tool for density estimation as well as clustering.
Python
# Python code with scikit-learn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
# Generate blob data, then stretch and rotate it so the clusters become
# elongated, which is exactly the kind of shape K-Means struggles with
X, y_true = make_blobs(n_samples=400, centers=4,
                       cluster_std=0.7, random_state=42)
X = np.dot(X, np.random.RandomState(0).randn(2, 2))  # apply a random linear transform
# Fit a Gaussian Mixture Model; the default covariance_type='full' lets
# each component learn its own elliptical shape and orientation
gmm = GaussianMixture(n_components=4, random_state=42)
gmm.fit(X)
y_gmm = gmm.predict(X)
# You can also get the per-component membership probabilities (soft assignments)
probabilities = gmm.predict_proba(X)
# Plot the results
plt.figure(figsize=(8, 5))
plt.scatter(X[:, 0], X[:, 1], c=y_gmm, s=40, cmap='viridis', zorder=2)
plt.title('Gaussian Mixture Model Clustering')
plt.show()
# Print the probabilities for the first 5 points
print("Probabilities for the first 5 data points:\n", probabilities[:5].round(3))