camillebrl's picture
Upload InternLMXComposer2ForCausalLM
ca6365b verified
raw
history blame contribute delete
No virus
7.23 kB
import torch
import torch.nn as nn
import re
import math
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
def build_vision_tower():
vision_tower = 'openai/clip-vit-large-patch14-336'
return CLIPVisionTower(vision_tower)
def build_vision_projector():
projector_type = 'mlp2x_gelu'
mm_hidden_size = 4096
mid_hidden_size = 4096
hidden_size = 4096
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(mm_hidden_size, mid_hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(mid_hidden_size, mid_hidden_size))
return nn.Sequential(*modules)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": 'identity'}
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = -1
self.select_feature = 'patch'
self.load_model()
def load_model(self):
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def resize_pos(self):
print ('Dummy Resized')
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features
def forward(self, images, glb_GN, sub_GN):
if not self.is_loaded:
self.load_model()
assert type(images) is list
shapes = []
input_imgs = []
for img in images:
_, C, H, W = img.shape
shapes.append([H//336, W//336])
sub_img = img.reshape(1,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()
glb_img = torch.nn.functional.interpolate(img.float(), size=(336,336), mode='bicubic',).to(sub_img.dtype)
input_imgs.append(glb_img)
input_imgs.append(sub_img)
input_imgs = torch.cat(input_imgs, dim=0)
image_forward_outs = self.vision_tower(input_imgs.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(input_imgs.dtype) ### B*?, N, C
_, N, C = image_features.shape
H = int(math.sqrt(N))
assert N == 24 ** 2
output_imgs = []
output_len = []
for [h, w] in shapes:
B_ = h*w
glb_img = image_features[:1] ### 1, N, C
glb_img = glb_img.reshape(1,H,H,C).reshape(1,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//2,H//2,4*C).contiguous()
temp_glb_GN = sub_GN.repeat(1, H//2, 1, 1)
glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,4*C)
sub_img = image_features[1:1+B_] ### ?, N, C
sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,4*C).contiguous()
sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(0,1,3,2,4,5).reshape(1,h*12,w*12,4*C)
temp_sub_GN = sub_GN.repeat(1, h*12, 1, 1)
sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,4*C)
output_imgs.append(torch.cat([glb_img, glb_GN, sub_img], dim=1))
temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
assert temp_len == output_imgs[-1].shape[1]
output_len.append(temp_len)
image_features = image_features[1+h*w:]
output_imgs = torch.cat(output_imgs, dim=1)
return output_imgs, output_len
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
class PLoRA(nn.Linear):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
lora_r=8,
lora_alpha=16,
lora_dropout=0.05,
lora_len=0,
**kwargs) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
self.lora_r = lora_r
self.lora_alpha = lora_alpha
self.lora_len = lora_len
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
self.lora_scaling = self.lora_alpha / self.lora_r
self.Plora_A = nn.Linear(in_features,
self.lora_r,
bias=False,
device=device,
dtype=dtype)
self.Plora_B = nn.Linear(self.lora_r,
out_features,
bias=False,
device=device,
dtype=dtype)
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
#print ("lora weight init {} {}".format(torch.mean(self.lora_A.weight), torch.mean(self.lora_B.weight)))
def forward(self, x, im_mask=None):
B, N, C = x.shape
x = x.reshape(-1, C)
im_mask = im_mask.view(-1)
res = super().forward(x)
if im_mask is not None:
if torch.sum(im_mask) > 0:
part_x = x[im_mask]
res[im_mask] += self.Plora_B(self.Plora_A(
self.lora_dropout(part_x))) * self.lora_scaling
else:
part_x = x[:1]
res[:1] += self.Plora_B(self.Plora_A(
self.lora_dropout(part_x))) * 0
return res.reshape(B, N, -1)