Implement continue finetuning.
haotian-liu committed Oct 18, 2023
1 parent f1a7b36 commit 232302e
Showing 2 changed files with 16 additions and 6 deletions.
18 changes: 13 additions & 5 deletions llava/model/llava_arch.py
@@ -47,20 +47,28 @@ def initialize_vision_modules(self, model_args, fsdp=None):
 
         self.config.mm_vision_tower = vision_tower
 
-        vision_tower = build_vision_tower(model_args)
+        if self.get_vision_tower() is None:
+            vision_tower = build_vision_tower(model_args)
 
-        if fsdp is not None and len(fsdp) > 0:
-            self.vision_tower = [vision_tower]
+            if fsdp is not None and len(fsdp) > 0:
+                self.vision_tower = [vision_tower]
+            else:
+                self.vision_tower = vision_tower
         else:
-            self.vision_tower = vision_tower
+            if fsdp is not None and len(fsdp) > 0:
+                vision_tower = self.vision_tower[0]
+            else:
+                vision_tower = self.vision_tower
+            vision_tower.load_model()
 
         self.config.use_mm_proj = True
         self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
         self.config.mm_hidden_size = vision_tower.hidden_size
         self.config.mm_vision_select_layer = mm_vision_select_layer
         self.config.mm_vision_select_feature = mm_vision_select_feature
 
-        self.mm_projector = build_vision_projector(self.config)
+        if getattr(self, 'mm_projector', None) is None:
+            self.mm_projector = build_vision_projector(self.config)
 
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
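
Because the interleaved hunk above is dense, here is a condensed sketch of the resulting control flow in initialize_vision_modules (surrounding code elided, not a verbatim copy of the method): when a vision tower is already attached, which is the continue-finetuning case, it is reused and only its weights are reloaded via load_model(); otherwise it is built from model_args. The projector is likewise only built when one is not already present.

    # Condensed sketch of the post-commit logic (not the full method):
    def initialize_vision_modules(self, model_args, fsdp=None):
        ...
        if self.get_vision_tower() is None:
            # Fresh setup: build a new vision tower from the model args.
            vision_tower = build_vision_tower(model_args)
            self.vision_tower = [vision_tower] if (fsdp is not None and len(fsdp) > 0) else vision_tower
        else:
            # Continue finetuning: reuse the tower that is already attached and reload its weights.
            vision_tower = self.vision_tower[0] if (fsdp is not None and len(fsdp) > 0) else self.vision_tower
            vision_tower.load_model()
        ...
        if getattr(self, 'mm_projector', None) is None:
            # Only create a projector when one is not already attached.
            self.mm_projector = build_vision_projector(self.config)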
4 changes: 3 additions & 1 deletion llava/train/train.py
@@ -163,12 +163,14 @@ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
 def find_all_linear_names(model):
     cls = torch.nn.Linear
     lora_module_names = set()
+    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if isinstance(module, cls):
             names = name.split('.')
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
-
     if 'lm_head' in lora_module_names: # needed for 16-bit
         lora_module_names.remove('lm_head')
     return list(lora_module_names)
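
For context, the list returned by find_all_linear_names is what LLaVA's LoRA setup passes as the target modules, so the new multimodal_keywords filter keeps the vision tower, resampler, and projector out of the low-rank adaptation. A minimal sketch of that wiring with the peft library follows, assuming model is the already-loaded LLaVA model; the hyperparameter values are illustrative placeholders, not values taken from this commit.

from peft import LoraConfig, get_peft_model

# Illustrative values only; in train.py these come from the parsed training arguments.
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    target_modules=find_all_linear_names(model),  # language-model Linear layers, multimodal parts excluded
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)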
