What is Image Segmentation?

Image segmentation is a computer vision task that partitions a digital image into multiple segments or regions. The goal is to assign a label to every pixel in the image, giving a far more granular understanding of the image content than object detection, which only localizes objects with bounding boxes.

Instead of just drawing a box around a car, segmentation can outline the exact shape of the car, pixel by pixel.

Types of Image Segmentation

There are two main categories of image segmentation:

  1. Semantic Segmentation: This is the most common type. It classifies each pixel as belonging to a particular class. For example, in a street scene, it would label all pixels that are part of any car as "car," all pixels that are part of any pedestrian as "pedestrian," and so on. It doesn't distinguish between different instances of the same class.

  2. Instance Segmentation: This is a more complex task that combines the ideas of object detection and semantic segmentation. It not only classifies each pixel but also differentiates between individual instances of a class. In our street scene, it would label pixels as "car-1," "car-2," "pedestrian-1," etc.
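
The difference between the two can be shown with toy label masks. The sketch below uses NumPy and made-up class IDs (0 = background, 1 = car for the semantic mask; instance IDs 1 and 2 for two separate cars):

```python
import numpy as np

# Semantic mask for a toy 4x4 scene containing two cars:
# every car pixel gets the same class ID (1).
semantic_mask = np.array([
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 1],
    [0, 0, 0, 1],
])

# Instance mask for the same scene: the two cars get
# distinct IDs ("car-1" = 1, "car-2" = 2).
instance_mask = np.array([
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 2],
    [0, 0, 0, 2],
])

print(np.unique(semantic_mask))  # → [0 1]    (background, car)
print(np.unique(instance_mask))  # → [0 1 2]  (background, car-1, car-2)
```

Note that the semantic mask can always be recovered from the instance mask by collapsing all instance IDs of a class back to one label; the reverse is not possible.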

The U-Net Architecture: The Gold Standard

While many architectures exist, the U-Net is the most famous and influential model for semantic segmentation, especially in medical imaging. Its unique design is exceptionally good at producing high-resolution output masks.

The U-Net has two main parts, forming a U-shape:

  • The Contracting Path (Encoder): This is a traditional convolutional neural network stack. It uses a series of convolutions and max-pooling layers to downsample the image. This path is responsible for capturing the context of the image—learning what objects are present.
  • The Expansive Path (Decoder): This path takes the low-resolution feature map from the encoder and uses upsampling convolutions (transposed convolutions) to gradually increase its resolution back to the original image size. This path is responsible for precise localization—learning where the objects are.
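
As a minimal sketch (hypothetical filter counts and input size, not the original U-Net configuration), the two paths look like this in Keras:

```python
import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.keras.Input(shape=(128, 128, 3))

# Contracting path: convolutions extract features, max-pooling halves
# the spatial resolution at each stage.
x = layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
x = layers.MaxPooling2D(2)(x)                     # 128x128 -> 64x64
x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
x = layers.MaxPooling2D(2)(x)                     # 64x64 -> 32x32

# Expansive path: transposed convolutions double the resolution back
# to the original input size.
x = layers.Conv2DTranspose(32, 3, strides=2, padding='same',
                           activation='relu')(x)  # 32x32 -> 64x64
x = layers.Conv2DTranspose(16, 3, strides=2, padding='same',
                           activation='relu')(x)  # 64x64 -> 128x128

# One logit per class at every pixel
outputs = layers.Conv2D(3, 1)(x)

model = tf.keras.Model(inputs, outputs)
```

A real U-Net stacks several more stages at each depth and, crucially, adds the skip connections described next.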

The Secret Sauce: Skip Connections

The key innovation of the U-Net is its skip connections. These connections feed feature maps directly from the encoder path to the corresponding layer in the decoder path, allowing the decoder to reuse the high-resolution spatial information learned in the early stages of the encoder, which would otherwise be lost during downsampling. This is critical for generating accurate, detailed segmentation masks.
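
In code, a skip connection is simply a concatenation of encoder feature maps onto the upsampled decoder feature maps. A minimal one-stage Keras sketch (illustrative channel counts, not the original U-Net configuration):

```python
import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.keras.Input(shape=(128, 128, 3))

# Encoder stage: features computed at full resolution, then downsampled
skip = layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
x = layers.MaxPooling2D(2)(skip)                  # 128x128 -> 64x64

# Decoder stage: upsample back to the encoder's resolution
x = layers.Conv2DTranspose(16, 3, strides=2,
                           padding='same')(x)     # 64x64 -> 128x128

# Skip connection: the high-resolution encoder features are concatenated
# onto the decoder features, restoring fine spatial detail.
x = layers.Concatenate()([x, skip])               # 16 + 16 = 32 channels
outputs = layers.Conv2D(3, 1)(x)                  # per-pixel class logits

model = tf.keras.Model(inputs, outputs)
```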

Case Study: Segmenting Roads for Autonomous Driving

Let's consider a practical application: identifying the drivable area for a self-driving car.

  • The Goal: Given an image from a car's front-facing camera, we want to create a segmentation mask that classifies every pixel as either "road" or "not road."
  • The Dataset: To train a model like U-Net, you need a specialized dataset.
      • Input Data: The raw camera images, with shape (height, width, 3).
      • Ground Truth Labels: Manually annotated mask images, with shape (height, width, 1). A mask is a grayscale image where each pixel's value corresponds to a class ID (e.g., 0 for background, 1 for road, 2 for sidewalk).
  • Training: You feed the input images to the U-Net model. The model outputs a predicted mask. The loss function (e.g., cross-entropy) then compares the predicted mask to the ground truth mask, pixel by pixel, and updates the model's weights to improve its accuracy.
  • Inference: Once trained, you can feed a new, unseen image to the model, and it will output a mask predicting the drivable road area in real time. This information is critical for the car's path-planning system.
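
To make the inference step concrete: a segmentation model typically outputs one score per class at every pixel, and the final mask is the arg-max over those scores. A toy NumPy example with made-up scores for a 2x2 image and three classes (0 = background, 1 = road, 2 = sidewalk):

```python
import numpy as np

# Hypothetical model output: shape (height, width, num_classes)
logits = np.array([
    [[2.0, 0.1, 0.3], [0.2, 3.0, 0.1]],   # row 0
    [[0.1, 2.5, 0.4], [0.3, 0.2, 1.9]],   # row 1
])

# At each pixel, keep the class with the highest score
predicted_mask = np.argmax(logits, axis=-1)

print(predicted_mask)
# [[0 1]
#  [1 2]]
```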

Code Snippet: Conceptual U-Net Training in Keras

Python

import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix # Contains a U-Net generator

# --- Assume you have loaded your dataset ---
# train_images: a dataset of RGB images
# train_masks: a dataset of corresponding segmentation masks

# --- 1. Load a U-Net model ---
# This example uses a pre-built U-Net from TensorFlow's examples.
# The output channels should match the number of classes you have.
OUTPUT_CHANNELS = 3 # e.g., for background, road, sidewalk
unet_model = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='batchnorm')

# --- 2. Define the loss and compile the model ---
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

unet_model.compile(optimizer='adam',
                   loss=loss,
                   metrics=['accuracy'])

# --- 3. Train the model ---
# The model learns to map the input images to the output masks.
model_history = unet_model.fit(train_images, train_masks, epochs=20)

# --- 4. Use the model for prediction ---
# The model outputs per-pixel class logits; take the arg-max over the
# channel axis to get the final class-ID mask.
# new_image = ...
# logits = unet_model.predict(new_image)
# predicted_mask = tf.argmax(logits, axis=-1)