Introduction
CLIP - Contrastive Language-Image Pre-training
CLIP takes images and text and connects them in a non-generative way. The important point is that it is trained on full sentences rather than single class labels like car, dog, etc. The intuition is that, when trained on whole sentences, the model can learn much richer concepts and find patterns linking images and text.
https://arxiv.org/pdf/2103.00020.pdf
Zero-Shot Classification
Zero-shot text classification is a task in natural language processing where a model is trained on a set of labeled examples but is then able to classify new examples from previously unseen classes. [1]
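For example, a minimal sketch of zero-shot text classification with the Hugging Face pipeline described in [1] (the model name is the pipeline's commonly used default and is an assumption here):

from transformers import pipeline

# Zero-shot text classification: the candidate labels were never seen during training
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier(
    "I have a problem with my iPhone that needs to be resolved asap!",
    candidate_labels=["urgent", "not urgent", "phone", "computer"],
)
print(result["labels"][0])  # the highest-scoring candidate label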
CLIP Contrastive training objective
1. Encoders
Text Encoder: This is a transformer. We can use CLIPEncoder from Hugging Face transformers; the base CLIP text encoder has 12 transformer blocks, each with 8-head multi-head attention.
Image Encoder: Multiple encoders have been tested: several ResNet variants and several Vision Transformer (ViT) variants. That is why we see multiple different published variants of the CLIP model.
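As an aside, a minimal sketch of loading pretrained CLIP encoders via Hugging Face transformers (the checkpoint name is just a commonly used example):

from transformers import CLIPTextModel, CLIPVisionModel, CLIPTokenizer

ckpt = "openai/clip-vit-base-patch32"  # example checkpoint; other clip-* variants work similarly

text_encoder = CLIPTextModel.from_pretrained(ckpt)     # 12-layer transformer text encoder
image_encoder = CLIPVisionModel.from_pretrained(ckpt)  # ViT-B/32 image encoder
tokenizer = CLIPTokenizer.from_pretrained(ckpt)

tokens = tokenizer(["a photo of a dog"], return_tensors="pt", padding=True)
text_out = text_encoder(**tokens)
print(text_out.last_hidden_state.shape)  # (1, seq_len, 512) for the base text encoder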
Training
An image is processed through an image encoder, yielding a vector in a latent space. Let I1, I2, and I3 denote the vectors for Image 1, Image 2, and Image 3, forming a mini-batch of images.
Simultaneously, the text input is fed into a text encoder, generating representations for the text. T1, T2, etc., represent the vectors for Text 1, Text 2, and so forth.
During batch training, I1 pairs with T1, I2 pairs with T2, and so forth.
The model is then queried to determine which representation (T1, T2, etc.) is most suitable for a given image. This defines the contrastive training objective: although the correct pairings are known from the training data (e.g., I1 with T1), the model is trained so that each image representation is maximally close to its correct text match and far away from all other candidates. The objective contrasts the known correct associations (the diagonal elements of the image-text similarity matrix) with the pairings that are known to be mismatched (the off-diagonal elements).
For each image, a classification task over all texts in the batch is performed (a row of the similarity matrix), and similarly, for each text, a classification task over all images is performed (a column of the matrix).
The objective is to maximize the inner product of paired items while minimizing the inner product of mismatched ones. A softmax classification is then performed along each row (per image) and along each column (per text) of the similarity matrix. This symmetric loss is applied both on the image side and the text side, essentially framing the problem as one classification task observed from two distinct viewpoints.
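To make this concrete, here is a minimal sketch of the symmetric loss using the straightforward targets 0..N-1, in the spirit of the pseudocode in the CLIP paper (the implementation later in this post uses soft targets instead; the temperature value here is just an assumption):

import torch
import torch.nn.functional as F

# Assume L2-normalized image and text embeddings, shape (batch_size, dim)
batch_size, dim = 8, 512
image_embeddings = F.normalize(torch.randn(batch_size, dim), dim=-1)
text_embeddings = F.normalize(torch.randn(batch_size, dim), dim=-1)

temperature = 0.07  # assumed value; the paper learns it
logits = image_embeddings @ text_embeddings.T / temperature  # (batch_size, batch_size)

# The i-th image matches the i-th text, so the correct class index is simply i
targets = torch.arange(batch_size)
loss_images = F.cross_entropy(logits, targets)    # classify the correct text for each image (rows)
loss_texts = F.cross_entropy(logits.T, targets)   # classify the correct image for each text (columns)
loss = (loss_images + loss_texts) / 2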
The efficacy of this approach is contingent upon the use of sufficiently large mini-batches: each pair is contrasted against every other example in the batch, so larger mini-batches provide more negatives to contrast against. As the mini-batch size approaches the size of the entire dataset, the representations become increasingly nuanced and comprehensive.
Inference
At inference time, we pass the image through the image encoder to get the latent vector representation of the image, I1.
The candidate labels are encoded using the text encoder to get T1, T2, T3, and so on.
We then take the inner product of I1 with each text vector and pick the one that is closest; the corresponding label is the prediction.
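Putting this together, a minimal sketch of zero-shot image classification with a pretrained CLIP checkpoint from Hugging Face transformers (checkpoint name, prompts, and image path are placeholders):

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

labels = ["cat", "dog", "car"]
prompts = [f"a photo of a {label}" for label in labels]
image = Image.open("example.jpg")  # any local image

inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    image_emb = model.get_image_features(pixel_values=inputs["pixel_values"])    # I1
    text_emb = model.get_text_features(input_ids=inputs["input_ids"],
                                       attention_mask=inputs["attention_mask"])  # T1, T2, ...

# Inner product between the image vector and every label vector (cosine similarity after normalizing)
image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
scores = image_emb @ text_emb.T              # shape (1, num_labels)
print(labels[scores.argmax(dim=-1).item()])  # predicted label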
Projection head
Having converted our images and texts into fixed-size vectors (768 for images and 512 for texts), the next step is to project them into a shared space with comparable dimensions. This process creates a common ground for both images and texts, enabling us to easily compare and differentiate between relevant and irrelevant pairs. The forthcoming code is designed to transform the 768-dimensional image vectors and 512-dimensional text vectors into a unified 512-dimensional space (projection_dim) for effective comparison.
import torch
from torch import nn

import config as CFG


class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        # Project to the shared space, then refine with a small residual MLP
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection
        x = self.layer_norm(x)
        return x
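A quick usage sketch with dummy tensors, assuming CFG.projection_dim = 512 and CFG.dropout = 0.1 (see the Configs note at the end):

# Hypothetical usage following the dimensions mentioned above
image_projection = ProjectionHead(embedding_dim=768)  # 768-d image features -> 512-d
text_projection = ProjectionHead(embedding_dim=512)   # 512-d text features  -> 512-d

image_features = torch.randn(32, 768)  # dummy batch of image encoder outputs
text_features = torch.randn(32, 512)   # dummy batch of text encoder outputs

print(image_projection(image_features).shape)  # torch.Size([32, 512])
print(text_projection(text_features).shape)    # torch.Size([32, 512])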
CLIP Model
import torch
from torch import nn
import torch.nn.functional as F

import config as CFG
from modules import ImageEncoder, TextEncoder, ProjectionHead


class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
        return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    # Cross entropy that accepts a full matrix of (soft) targets
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
Here we use the modules we built previously to implement the main model. The __init__ function is self-explanatory. In the forward function, we first encode the images and texts separately into fixed-size vectors (with different dimensionalities). After that, using separate projection modules, we project them into that shared world (space) I talked about previously; the embeddings now have the same shape (512 in our case). Then we compute the loss. Again, I recommend reading the CLIP paper to understand it better, but I'll try my best to explain this part.
In linear algebra, one common way to measure whether two vectors have similar characteristics (i.e., they are like each other) is to calculate their dot product (multiply the matching entries and take the sum); if the result is big, they are alike, and if it is small, they are not (relatively speaking)!
Okay! What I just said is the most important thing to keep in mind to understand this loss function. Let's continue. We talked about two vectors, but what do we have here? We have image_embeddings, a matrix with shape (batch_size, 512), and text_embeddings with shape (batch_size, 512). Easy enough! It means we have two groups of vectors instead of two single vectors. How do we measure how similar two groups of vectors (two matrices) are to each other? Again, with the dot product (the @ operator in PyTorch performs matrix multiplication here, which computes all the pairwise dot products at once). To be able to multiply these two matrices together, we transpose the second one. We get a matrix with shape (batch_size, batch_size), which we call logits. (temperature is equal to 1.0 in our case, so it does not make a difference; you can play with it and see what difference it makes. Also look at the paper to see why it is there!)
I hope you are still with me! If not, it's okay, just review the code and check the shapes. Now that we have our logits, we need targets. I should mention that there is a more straightforward way to obtain targets (like the arange-target sketch shown earlier in the Training section), but I had to do it this way for our case (I'll explain why in a later paragraph).
Let's consider what we hope this model learns: we want it to learn "similar representations (vectors)" for a given image and the caption describing it, meaning that whether we give it an image or the text describing it, we want it to produce the same 512-dimensional vector for both.
So, in the best-case scenario, the text_embeddings and image_embeddings matrices should be the same because they describe similar things. Now let's think: if this happens, what would the logits matrix look like? Let's see with a simple example!
import torch
import torch.nn.functional as F

batch_size = 4
dim = 256
embeddings = torch.randn(batch_size, dim)
out = embeddings @ embeddings.T
print(F.softmax(out, dim=-1))

# tensor([[1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 0., 0., 1.]])
So logits, in the best case, will be a matrix which, if we take its softmax, has 1.0s on the diagonal (an identity matrix, to call it by its fancy name!). As the loss function's job is to make the model's predictions similar to the targets (at least in most cases!), we want such a matrix as our target. That's the reason why we calculate the images_similarity and texts_similarity matrices in the code block above.
Now that we've got our targets matrix, we use simple cross entropy to calculate the actual loss. I've written the full matrix form of cross entropy as a function, which you can see at the bottom of the code block. Okay! We are done! Wasn't it simple?! Alright, you can ignore the next paragraph, but if you are curious, there is an important note in it.
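A side note not from the original write-up: in recent PyTorch versions (1.10+), F.cross_entropy can itself take class probabilities as targets, so with reduction='none' the custom function above behaves essentially the same as the built-in:

import torch
import torch.nn.functional as F

preds = torch.randn(4, 4)                       # logits, shape (batch_size, batch_size)
targets = F.softmax(torch.randn(4, 4), dim=-1)  # soft targets, rows sum to 1

manual = (-targets * F.log_softmax(preds, dim=-1)).sum(1)   # the custom cross_entropy above
builtin = F.cross_entropy(preds, targets, reduction='none') # PyTorch >= 1.10 accepts probability targets
print(torch.allclose(manual, builtin))  # True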
References
[1] https://huggingface.co/tasks/zero-shot-classification
[2] https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2
Configs:
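The code above imports a config module as CFG. Below is a hypothetical sketch of the values it would need, consistent with the numbers used in this write-up (the dropout value is purely an assumption):

# config.py -- hypothetical values consistent with the text above
projection_dim = 512   # shared embedding space size (see "Projection head")
dropout = 0.1          # assumed dropout rate for the projection head
temperature = 1.0      # the text notes temperature is 1.0 in this setup
image_embedding = 768  # image encoder output size (768 for images, per the text)
text_embedding = 512   # text encoder output size (512 for texts, per the text)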