diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
index fd538c937..58b2b37df 100644
--- a/llava/model/llava_arch.py
+++ b/llava/model/llava_arch.py
@@ -47,12 +47,19 @@ def initialize_vision_modules(self, model_args, fsdp=None):
 
         self.config.mm_vision_tower = vision_tower
 
-        vision_tower = build_vision_tower(model_args)
+        if self.get_vision_tower() is None:
+            vision_tower = build_vision_tower(model_args)
 
-        if fsdp is not None and len(fsdp) > 0:
-            self.vision_tower = [vision_tower]
+            if fsdp is not None and len(fsdp) > 0:
+                self.vision_tower = [vision_tower]
+            else:
+                self.vision_tower = vision_tower
         else:
-            self.vision_tower = vision_tower
+            if fsdp is not None and len(fsdp) > 0:
+                vision_tower = self.vision_tower[0]
+            else:
+                vision_tower = self.vision_tower
+            vision_tower.load_model()
 
         self.config.use_mm_proj = True
         self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
@@ -60,7 +67,8 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         self.config.mm_vision_select_layer = mm_vision_select_layer
         self.config.mm_vision_select_feature = mm_vision_select_feature
 
-        self.mm_projector = build_vision_projector(self.config)
+        if getattr(self, 'mm_projector', None) is None:
+            self.mm_projector = build_vision_projector(self.config)
 
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
diff --git a/llava/train/train.py b/llava/train/train.py
index 0d198f7e3..bfffca705 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -163,12 +163,14 @@ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
 def find_all_linear_names(model):
     cls = torch.nn.Linear
     lora_module_names = set()
+    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if isinstance(module, cls):
             names = name.split('.')
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
-
     if 'lm_head' in lora_module_names: # needed for 16-bit
         lora_module_names.remove('lm_head')
     return list(lora_module_names)
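
Note (not part of the patch): the list returned by `find_all_linear_names` is typically handed to a LoRA configuration as its `target_modules`, so skipping names containing `mm_projector`, `vision_tower`, or `vision_resampler` keeps LoRA adapters off the multimodal modules and restricts them to the language model's Linear layers. Below is a minimal sketch of that usage, assuming HuggingFace `peft` and an already-loaded `model`; the hyperparameter values are illustrative, not taken from the patched files.

```python
# Illustrative sketch only -- not part of the diff above.
# Assumes `peft` is installed and `model` is a loaded LLaVA-style model.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=128,                # example rank, not a prescribed value
    lora_alpha=256,       # example scaling factor
    # Only language-model Linear layers are returned, because
    # find_all_linear_names now skips the multimodal keyword modules.
    target_modules=find_all_linear_names(model),
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
# The vision tower and mm_projector stay adapter-free.
model = get_peft_model(model, lora_config)
```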