diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 1999973ea88..4c7297e087e 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -94,7 +94,7 @@ def get_model(self, model_name:str):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
@@ -191,13 +191,6 @@ def _load_model(self, model_name:str):
             return None
 
         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae',None)
-        width = mconfig.width
-        height = mconfig.height
-
-        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
@@ -207,15 +200,44 @@ def _load_model(self, model_name:str):
         tic = time.time()
 
         # this does the work
-        c     = OmegaConf.load(config)
-        with open(weights,'rb') as f:
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
+        with open(weights, 'rb') as f:
             weight_bytes = f.read()
-        model_hash  = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights, weight_bytes)
         pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
-        sd    = pl_sd['state_dict']
+        sd = pl_sd['state_dict']
         model = instantiate_from_config(c.model)
-        m, u  = model.load_state_dict(sd, strict=False)
+        m, u = model.load_state_dict(sd, strict=False)
 
         if self.precision == 'float16':
             print('   | Using faster float16 precision')
@@ -243,18 +265,11 @@ def _load_model(self, model_name:str):
             if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 m._orig_padding_mode = m.padding_mode
 
-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
         return model, width, height, model_hash
-
+
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError()  # return pipeline, width, height, model_hash
+
     def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
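
Note: the new dispatch in _load_model() keys off an optional `format` field in
each models.yaml stanza, defaulting to 'ckpt' so existing entries keep loading
exactly as before. A sketch of what the two kinds of entries might look like;
the diffusers stanza and its `repo_id` field are illustrative assumptions, since
this diff only reads `format`, `config`, `weights`, `vae`, `width`, and `height`:

    stable-diffusion-1.5:
      description: Stable Diffusion 1.5 checkpoint   # handled by _load_ckpt_model()
      format: ckpt
      config: configs/stable-diffusion/v1-inference.yaml
      weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
      width: 512
      height: 512

    diffusers-1.5:
      description: Stable Diffusion 1.5, diffusers layout   # handled by _load_diffusers_model()
      format: diffusers
      repo_id: runwayml/stable-diffusion-v1-5        # hypothetical field name
      width: 512
      height: 512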
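
The diffusers branch is deliberately left as a stub. For orientation only, here
is a minimal sketch of what _load_diffusers_model() might grow into. It assumes
a `repo_id`/`path` field in mconfig and a `self.device` attribute on the cache;
only `self.precision` is confirmed by this diff. StableDiffusionPipeline.from_pretrained
is the standard diffusers entry point, so this is a sketch under those
assumptions, not the method this PR will necessarily land:

    # Hypothetical sketch -- not part of this diff.
    import hashlib

    import torch
    from diffusers import StableDiffusionPipeline

    def _load_diffusers_model(self, mconfig):
        # `repo_id` (Hugging Face Hub id) and `path` (local snapshot) are
        # assumed field names; the diff does not define the diffusers schema.
        name_or_path = mconfig.get('repo_id') or mconfig.get('path')
        width = mconfig.width
        height = mconfig.height
        print(f'>> Loading {name_or_path} as a diffusers pipeline')

        dtype = torch.float16 if self.precision == 'float16' else torch.float32
        pipeline = StableDiffusionPipeline.from_pretrained(
            name_or_path,
            torch_dtype=dtype,
        )
        pipeline.to(self.device)  # assumes the cache tracks a target device

        # A diffusers model is a directory tree rather than a single .ckpt,
        # so _cached_sha256(weights, weight_bytes) does not apply directly;
        # hashing the identifier is a placeholder until a real scheme exists.
        model_hash = hashlib.sha256(name_or_path.encode('utf-8')).hexdigest()

        return pipeline, width, height, model_hash

Whatever the final implementation, it only has to honor the
(model, width, height, model_hash) contract that _load_model() now expects
from both loaders.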