Probing Vision Transformers


Sayak Paul

Added on September 14, 2024


Probing ViTs

By Aritra Roy Gosthipaty and Sayak Paul (equal contribution)

In this repository, we provide tools to probe into the representations learned by different families of Vision Transformers (supervised pre-training with ImageNet-21k, ImageNet-1k, distillation, self-supervised pre-training):

  • Original ViT [1]
  • DeiT [2]
  • DINO [3]

We hope these tools prove useful to the community. Please follow along with this post on keras.io for better navigation of the repository.

Updates

Self-attention visualization

(Figure: original image, attention maps, and attention maps overlaid.)

https://user-images.githubusercontent.com/36856589/162609884-8e51156e-d461-421d-9f8a-4d4e48967bd6.mp4

Original Video Source

https://user-images.githubusercontent.com/36856589/162609907-4e432dc4-a731-40f4-9a20-94e0c8f648bc.mp4

Original Video Source

Supervised salient representations

In the DINO blog post, the authors show a video with the following caption:

The original video is shown on the left. In the middle is a segmentation example generated by a supervised model, and on the right is one generated by DINO.

A screenshot of the video is as follows:

(Screenshot from the DINO blog post.)

We obtain the attention maps generated by the supervised pre-trained model and find that they are not as salient as those of the DINO model; we observe the same behaviour in our own experiments. The figure below shows the attention heatmaps extracted with a ViT-B16 model pre-trained (supervised) on ImageNet-1k:

(Attention heatmaps for two example inputs: a dinosaur and a dog.)

We used this Colab Notebook to conduct this experiment.
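For reference, extracting such a heatmap from the last layer's attention weights amounts to taking the CLS token's attention over the patch tokens and reshaping it to the patch grid. A minimal NumPy sketch (the 14x14 grid assumes ViT-B16 at 224x224 resolution; the function name is ours, not the repository's):

```python
import numpy as np

def cls_attention_heatmap(attn, grid=14):
    """attn: (num_heads, num_tokens, num_tokens) last-layer attention,
    with token 0 being the CLS token (num_tokens = grid * grid + 1).

    Returns one (grid, grid) map per head: how strongly CLS attends
    to each image patch.
    """
    cls_to_patches = attn[:, 0, 1:]              # drop attention to CLS itself
    heatmaps = cls_to_patches.reshape(-1, grid, grid)
    # Normalize each head's map to [0, 1] for overlaying on the image.
    mins = heatmaps.min(axis=(1, 2), keepdims=True)
    maxs = heatmaps.max(axis=(1, 2), keepdims=True)
    return (heatmaps - mins) / (maxs - mins + 1e-8)
```

In practice the per-head maps are upsampled to the input resolution before being overlaid on the original image.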

Hugging Face Spaces

You can now probe the ViTs with your own input images:

  • Attention Heat Maps
  • Attention Rollout

Visualizing mean attention distances

Methods

We don't propose any novel methods for probing the representations of neural networks. Instead, we take existing works and implement them in TensorFlow.

  • Mean attention distance [1, 4]
  • Attention Rollout [5]
  • Visualization of the learned projection filters [1]
  • Visualization of the learned positional embeddings
  • Attention maps from individual attention heads [3]
  • Generation of attention heatmaps from videos [3]
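Mean attention distance [1, 4] is the attention-weighted average of the spatial distance between a query patch and the patches it attends to, averaged over all queries; it serves as a proxy for a head's "receptive field size". A minimal NumPy sketch for a single head (the function name and the pixel-space distance convention are our assumptions, not the papers' exact code):

```python
import numpy as np

def mean_attention_distance(attn, patch_size=16, grid=14):
    """Attention-weighted mean spatial distance for one head.

    attn: (num_patches, num_patches) attention weights (rows sum to 1),
    CLS token excluded. Distances are measured in pixels between patch
    positions on a `grid` x `grid` layout of `patch_size`-pixel patches.
    """
    coords = np.stack(
        np.meshgrid(np.arange(grid), np.arange(grid), indexing="ij"), axis=-1
    ).reshape(-1, 2) * patch_size
    # Pairwise Euclidean distances between every query and key patch.
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    # Weight the distances by attention, then average over query tokens.
    return float((attn * dist).sum(axis=-1).mean())
```

Heads that attend mostly to nearby patches yield small values; heads with global attention yield large ones, which is how [4] contrasts ViTs with CNNs.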

Another interesting repository that also visualizes ViTs in PyTorch: https://github.com/jacobgil/vit-explain.
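Attention rollout [5] approximates how attention propagates through the network by recursively multiplying the per-layer attention matrices, with the identity added to account for residual connections. A hedged NumPy sketch (we assume the heads have already been averaged per layer):

```python
import numpy as np

def attention_rollout(attentions):
    """attentions: list of (num_tokens, num_tokens) head-averaged
    attention matrices, ordered from the first layer to the last.
    """
    n = attentions[0].shape[0]
    rollout = np.eye(n)
    for attn in attentions:
        # Account for the residual connection, then re-normalize rows.
        attn = attn + np.eye(n)
        attn = attn / attn.sum(axis=-1, keepdims=True)
        # Compose this layer's attention with the flow so far.
        rollout = attn @ rollout
    return rollout
```

The CLS row of the result, reshaped to the patch grid, gives the rollout maps shown in the Spaces demo above.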

Notes

We first implemented the above-mentioned architectures in TensorFlow and then populated them with the pre-trained parameters from the official codebases. To validate this, we evaluated our implementations on the ImageNet-1k validation set and ensured that the reported top-1 accuracies matched.
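The validation step boils down to comparing top-1 accuracy against the officially reported number. As a trivial sketch (assuming logits and integer labels as NumPy arrays; the helper name is ours):

```python
import numpy as np

def top1_accuracy(logits, labels):
    """logits: (n, num_classes) model outputs; labels: (n,) class ids."""
    # A prediction counts as correct if the argmax class matches the label.
    return float((logits.argmax(axis=-1) == labels).mean())
```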

We value the spirit of open source. If you spot any bugs in the code or see scope for improvement, don't hesitate to open an issue or contribute a PR. We'd very much appreciate it.

Our ViT implementations are in vit. Utility notebooks are provided in the notebooks directory.

DeiT-related code has its separate repository: https://github.com/sayakpaul/deit-tf.

Models

Links to the models into which the pre-trained parameters were populated are provided in the repository.

Training and visualizing with small datasets

Coming soon!

References

[1] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale: https://arxiv.org/abs/2010.11929

[2] Training data-efficient image transformers & distillation through attention (DeiT): https://arxiv.org/abs/2012.12877

[3] Emerging Properties in Self-Supervised Vision Transformers (DINO): https://arxiv.org/abs/2104.14294

[4] Do Vision Transformers See Like Convolutional Neural Networks?: https://arxiv.org/abs/2108.08810

[5] Quantifying Attention Flow in Transformers: https://arxiv.org/abs/2005.00928

Acknowledgements
