import torch
from torch import nn
from torch.nn import functional as F

from attention import SelfAttention


class CLIPEmbedding(nn.Module):
    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        # Learnable positional embedding, one vector per token position.
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))

    def forward(self, tokens):
        # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
        x = self.token_embedding(tokens)
        # Add positional information to every token embedding.
        x += self.position_embedding
        return x


class CLIPLayer(nn.Module):
    def __init__(self, n_head: int, n_embd: int):
        super().__init__()

        # Pre-attention normalization.
        self.layernorm_1 = nn.LayerNorm(n_embd)
        self.attention = SelfAttention(n_head, n_embd)
        # Pre-feedforward normalization.
        self.layernorm_2 = nn.LayerNorm(n_embd)
        # Feedforward network with a 4x expansion factor.
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        ## SELF-ATTENTION block with residual connection ##
        residue = x
        x = self.layernorm_1(x)
        # Causal mask: each token can only attend to itself and previous tokens.
        x = self.attention(x, causal_mask=True)
        x += residue

        ## FEEDFORWARD block with residual connection ##
        residue = x
        x = self.layernorm_2(x)
        x = self.linear_1(x)
        # QuickGELU activation: x * sigmoid(1.702 * x).
        x = x * torch.sigmoid(1.702 * x)
        x = self.linear_2(x)
        x += residue

        return x


class CLIP(nn.Module):
    def __init__(self):
        super().__init__()
        # Vocabulary size 49408, embedding dim 768, max sequence length 77.
        self.embedding = CLIPEmbedding(49408, 768, 77)

        # 12 transformer encoder layers with 12 attention heads each.
        self.layers = nn.ModuleList([
            CLIPLayer(12, 768) for _ in range(12)
        ])

        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        tokens = tokens.type(torch.long)

        # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
        state = self.embedding(tokens)

        for layer in self.layers:
            state = layer(state)

        # Final layer norm: (Batch_Size, Seq_Len, Dim)
        output = self.layernorm(state)

        return output
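

# Usage sketch (illustrative, not part of the original module): this assumes the
# SelfAttention module imported from attention.py is available and that token ids
# come from a CLIP-style tokenizer (vocab size 49408, sequences padded to 77).
# Random token ids are used here only to show the expected tensor shapes.
if __name__ == "__main__":
    clip = CLIP()

    # Dummy batch of 2 prompts, each padded/truncated to 77 token ids.
    tokens = torch.randint(0, 49408, (2, 77), dtype=torch.long)

    with torch.no_grad():
        out = clip(tokens)

    # One 768-dimensional embedding per token position.
    print(out.shape)  # expected: torch.Size([2, 77, 768])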