Transformers in Vision

Procedure

The objective of this experiment is to apply a Vision Transformer (ViT) model to an image classification task and to study how attention-based mechanisms and patch representations enable effective learning from visual data. This experiment emphasizes understanding pretrained transformer models, fine-tuning strategies, and visualization of attention maps using a subset of the CIFAR-10 dataset.

1. Import Required Libraries

  • PyTorch: for building and training the deep learning model.
  • torchvision: for dataset loading and image transformations (plus pretrained ViT utilities).
  • NumPy: for numerical operations.
  • matplotlib: for visualization and plots.

2. Dataset Loading and Description

  • Load the CIFAR-10 dataset using torchvision dataset utilities.
  • The dataset consists of:
    • 60,000 RGB images of size 32 × 32 pixels
    • 10 object classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck
    • 50,000 training images and 10,000 test images
  • To reduce computational cost and training time, a subset of CIFAR-10 is selected in the code while maintaining class diversity.
  • Because pretrained Vision Transformers are usually trained at a much higher resolution than 32 × 32 (commonly 224 × 224), images are resized to the pretrained ViT input size before being passed to the model.
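The effect of resizing on the patch representation can be illustrated with a little arithmetic. The sketch below assumes a 224 × 224 input and 16 × 16 patches (a common ViT-B/16 style configuration, not something fixed by this procedure); the helper name `num_patches` is hypothetical.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


# Assumed configuration: 224 x 224 input, 16 x 16 patches.
print(num_patches(224, 16))  # 196 patches (a 14 x 14 grid)

# A raw 32 x 32 CIFAR-10 image would yield only a 2 x 2 grid at this patch size.
print(num_patches(32, 16))   # 4
```

With only four patches per image there would be very little for self-attention to relate, which is one practical reason for resizing before the model.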

3. Data Preprocessing

  • Resize images to the resolution required by the pretrained Vision Transformer.
  • Convert the resized images into tensors suitable as transformer input.
  • Normalize images using the mean and standard deviation values associated with the pretrained ViT model to ensure compatibility and stable training.

4. Define Class Labels

  • Initialize a tuple containing the 10 CIFAR-10 class names.
  • This mapping is used to convert numerical class labels into human-readable class names during visualization and evaluation.
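A minimal sketch of the label mapping follows. The names `CLASSES` and `label_to_name` are illustrative choices, not taken from the experiment code; the class order is the standard CIFAR-10 index order.

```python
# CIFAR-10 class names, ordered so that the tuple index equals the numeric label.
CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def label_to_name(label) -> str:
    """Convert a numeric label (int or 0-d tensor) to its class name."""
    return CLASSES[int(label)]


print(label_to_name(3))  # cat
```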

5. Load CIFAR-10 Dataset

  • Load the CIFAR-10 training and test datasets using torchvision dataset loaders.
  • The dataset contains 60,000 RGB images of 32 × 32 pixels, labeled with integers 0 to 9 for the 10 object classes:
    • 0 → airplane
    • 1 → automobile
    • 2 → bird
    • 3 → cat
    • 4 → deer
    • 5 → dog
    • 6 → frog
    • 7 → horse
    • 8 → ship
    • 9 → truck
  • Download the dataset automatically if it is not already available locally.

6. Create a CIFAR-10 Subset

  • Select a subset of the CIFAR-10 dataset to reduce computational requirements.
  • Ensure that the subset maintains class diversity so that all categories are represented.
  • This step allows faster experimentation while preserving the learning behavior of the Vision Transformer.
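One way to keep every class represented is to sample a fixed number of indices per class. The following is a sketch under that assumption; the helper `stratified_indices`, the per-class count, and the toy labels are all illustrative, and the experiment code may select its subset differently.

```python
import random
from collections import defaultdict


def stratified_indices(labels, per_class, seed=0):
    """Pick up to `per_class` sample indices for every class found in `labels`."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[int(label)].append(idx)

    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    chosen = []
    for label in sorted(by_class):
        pool = by_class[label]
        chosen.extend(rng.sample(pool, min(per_class, len(pool))))
    return sorted(chosen)


# Toy labels standing in for the real label list (200 samples, 10 classes).
toy_labels = [i % 10 for i in range(200)]
subset_idx = stratified_indices(toy_labels, per_class=5)
print(len(subset_idx))  # 50
```

With a torchvision CIFAR-10 dataset, the label list is available as `dataset.targets`, and the resulting indices can be wrapped with `torch.utils.data.Subset(dataset, subset_idx)`.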

7. Define Image Transformations

  • Resize input images to the resolution expected by the pretrained Vision Transformer.
  • Convert images to tensor format.
  • Normalize images using the mean and standard deviation values required by the pretrained ViT model.
  • These transformations ensure compatibility between CIFAR-10 images and the pretrained model.

8. Create Data Loaders

  • Create training and test DataLoader objects using the transformed datasets.
  • Specify an appropriate batch size for efficient training.
  • Enable shuffling for the training data so that the model does not pick up on the order in which samples are presented.
  • Use parallel data loading where supported to improve performance.
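A minimal sketch of the loaders is shown below. The batch size and the helper name `make_loaders` are assumptions; random tensors stand in for the transformed CIFAR-10 subset so the snippet runs without any download. Setting `num_workers` above 0 enables parallel loading on platforms that support it.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(train_set, test_set, batch_size=32, num_workers=0):
    """Shuffle only the training data; keep the test order fixed."""
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


# Synthetic stand-ins for the real datasets (images, integer labels).
train_set = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
test_set = TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 10, (16,)))

train_loader, test_loader = make_loaders(train_set, test_set, batch_size=8)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)
```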

9. Load Pretrained Vision Transformer Model

  • Load a pretrained Vision Transformer (ViT) model using torchvision or timm utilities.
  • The model is initialized with weights trained on a large-scale image dataset.
  • Replace or reconfigure the final classification head to output predictions for 10 CIFAR-10 classes.

10. Freeze and Unfreeze Model Layers

  • Freeze selected transformer layers to prevent their weights from updating during training.
  • Allow the classification head (and optionally the last transformer layers) to be trainable.
  • This fine-tuning strategy reuses the pretrained features and updates far fewer parameters, which shortens training and lowers the risk of overfitting on a small subset.

11. Define Loss Function

  • Define the cross-entropy loss function for multi-class classification.
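  • This loss function measures the mismatch between the predicted class distribution and the true class labels; in PyTorch it is applied directly to the model's raw output scores (logits), without a separate softmax layer.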
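As a small worked example (with made-up logits and labels): when the model assigns equal scores to all 10 classes, each class gets probability 1/10, so the cross-entropy equals ln(10) ≈ 2.30. This is a useful reference value for the loss at the start of training.

```python
import math

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()

# All-zero logits correspond to a uniform prediction over the 10 classes.
logits = torch.zeros(4, 10)
labels = torch.tensor([0, 3, 5, 9])

loss = criterion(logits, labels)
print(loss.item(), math.log(10))  # both approximately 2.3026
```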

12. Configure Optimizer

  • Configure an optimizer to update trainable model parameters.
  • Set an appropriate learning rate suitable for fine-tuning a pretrained model.
  • The optimizer controls how model weights are updated during backpropagation.
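A possible configuration is sketched below. AdamW, the learning rate of 1e-4, and the weight decay are assumptions chosen as common fine-tuning defaults, not values prescribed by this procedure; the small `nn.Sequential` model is a stand-in with its first layer frozen.

```python
import torch
from torch import nn


def make_optimizer(model, lr=1e-4, weight_decay=0.01):
    """Build AdamW over the parameters that are still trainable."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)


# Stand-in model: the first layer is frozen, the second acts as the "head".
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 10))
for param in model[0].parameters():
    param.requires_grad = False

optimizer = make_optimizer(model)
print(len(optimizer.param_groups[0]["params"]))  # 2 (weight and bias of the head)
```

Passing only trainable parameters keeps the optimizer from allocating state for frozen weights.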

13. Training Loop Implementation

  • Iterate over the training dataset for a fixed number of epochs.
  • For each batch:
    • Perform a forward pass through the Vision Transformer.
    • Compute the loss using predicted outputs and true labels.
    • Perform backpropagation to compute gradients.
    • Update model parameters using the optimizer.
  • Track training loss and accuracy for each epoch.
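The loop above can be sketched as a single-epoch function. To keep the example fast and self-contained, a small linear classifier and random data stand in for the Vision Transformer and CIFAR-10; the function name `train_one_epoch`, the learning rate, and the epoch count are assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    """Run one pass over `loader`; return (mean loss, accuracy) for the epoch."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()              # clear gradients from the previous step
        logits = model(images)             # forward pass
        loss = criterion(logits, labels)   # compare predictions with true labels
        loss.backward()                    # backpropagation
        optimizer.step()                   # parameter update

        total_loss += loss.item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen


torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))  # stand-in for the ViT
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

history = [train_one_epoch(model, loader, criterion, optimizer) for _ in range(5)]
for epoch, (loss, acc) in enumerate(history, start=1):
    print(f"epoch {epoch}: loss={loss:.4f} acc={acc:.3f}")
```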

14. Model Evaluation

  • Switch the model to evaluation mode.
  • Disable gradient computation to reduce memory usage and speed up inference.
  • Evaluate the trained model on the test dataset.
  • Compute classification accuracy to assess model generalization.
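Evaluation might look like the sketch below. The function name `evaluate` is an assumption; the sanity check at the end uses an identity "model" on one-hot inputs, for which every prediction is correct by construction.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device="cpu"):
    """Return classification accuracy of `model` over `loader`."""
    model.eval()                      # disable dropout and similar training behavior
    correct, seen = 0, 0
    with torch.no_grad():             # no gradients are needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            seen += labels.size(0)
    return correct / seen


# Sanity check: one-hot "logits" passed through an identity module.
labels = torch.arange(10).repeat(2)
inputs = nn.functional.one_hot(labels, num_classes=10).float()
loader = DataLoader(TensorDataset(inputs, labels), batch_size=5)

accuracy = evaluate(nn.Identity(), loader)
print(accuracy)  # 1.0
```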

15. Extract Self-Attention Weights

  • Access attention weights from the transformer encoder layers.
  • Extract attention matrices corresponding to image patches.
  • These weights represent how each patch attends to other patches in the image.
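How the weights are accessed depends on the library: some ViT implementations do not return attention weights by default, so a hook or a small wrapper around the attention module may be needed. The sketch below sidesteps that by using a standalone `nn.MultiheadAttention` layer as a stand-in for one encoder layer's self-attention, with random tokens in place of real patch embeddings. The token count, embedding size, and head count are arbitrary assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_patches, embed_dim, num_heads = 16, 32, 4

# One class (CLS) token followed by the patch tokens: shape (batch, 1 + patches, dim).
tokens = torch.randn(1, 1 + num_patches, embed_dim)

attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
attention.eval()

with torch.no_grad():
    # average_attn_weights=False keeps one attention matrix per head.
    _, weights = attention(tokens, tokens, tokens,
                           need_weights=True, average_attn_weights=False)

# weights has shape (batch, heads, query token, key token); each row sums to 1.
print(weights.shape)

# Attention from the CLS token (query 0) to each patch, averaged over heads.
cls_to_patches = weights[:, :, 0, 1:].mean(dim=1)
print(cls_to_patches.shape)
```

The CLS-to-patch row is a common choice for visualization because the CLS token feeds the classification head.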

16. Visualization of Patch Attention Maps

  • Select sample images from the test dataset.
  • Visualize patch-level attention maps overlaid on the input image.
  • Analyze which image regions the Vision Transformer focuses on while making predictions.
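A simple overlay can be produced by reshaping the per-patch attention into a grid, upsampling it to the image size, and drawing it semi-transparently over the image. The sketch below assumes a 14 × 14 patch grid on a 224 × 224 image and uses random attention values; the helper names, colormap, and alpha are illustrative choices.

```python
import math

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


def attention_heatmap(patch_attention, image_size):
    """Turn a 1-D tensor of g*g patch scores into an (H, W) array scaled to [0, 1]."""
    grid = math.isqrt(patch_attention.numel())
    if grid * grid != patch_attention.numel():
        raise ValueError("number of patches must be a perfect square")
    heat = patch_attention.reshape(1, 1, grid, grid).float()
    heat = F.interpolate(heat, size=(image_size, image_size),
                         mode="bilinear", align_corners=False)[0, 0]
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    return heat.numpy()


def plot_overlay(image, heat, title=""):
    """Draw `heat` over `image` (an (H, W, 3) array in [0, 1]); return the figure."""
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(image)
    ax.imshow(heat, cmap="jet", alpha=0.5)  # semi-transparent attention map
    ax.set_title(title)
    ax.axis("off")
    return fig


# Random scores stand in for real CLS-to-patch attention (14 x 14 = 196 patches).
heat = attention_heatmap(torch.rand(196), image_size=224)
print(heat.shape)  # (224, 224)
```

For a real sample, `image` must be the unnormalized picture (values in [0, 1]), and the title can show the predicted class name next to the true one.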