refactor(model_cache): factor out load_ckpt
keturn committed Nov 9, 2022
1 parent a267b45 commit 9f5e496
Showing 1 changed file with 39 additions and 24 deletions.
ldm/invoke/model_cache.py (+39 −24)
@@ -94,7 +94,7 @@ def get_model(self, model_name:str):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
        '''
        Returns the name of the default model, or None
        if none is defined.
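
A note on the widened return annotation: `str | None` is PEP 604 union syntax, which Python only accepts at runtime from 3.10 onward (or on earlier versions with `from __future__ import annotations`). A minimal sketch of the backward-compatible spelling, should older interpreters need to be supported:

from typing import Optional

def default_model(self) -> Optional[str]:
    # Equivalent to `-> str | None` on any Python 3 version.
    ...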
@@ -191,13 +191,6 @@ def _load_model(self, model_name:str):
             return None
 
         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae',None)
-        width = mconfig.width
-        height = mconfig.height
-
-        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
@@ -207,15 +200,44 @@ def _load_model(self, model_name:str):
         tic = time.time()
 
         # this does the work
-        c = OmegaConf.load(config)
-        with open(weights,'rb') as f:
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
+        with open(weights, 'rb') as f:
             weight_bytes = f.read()
-        model_hash = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights, weight_bytes)
         pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
-        sd = pl_sd['state_dict']
+        sd = pl_sd['state_dict']
         model = instantiate_from_config(c.model)
-        m, u = model.load_state_dict(sd, strict=False)
+        m, u = model.load_state_dict(sd, strict=False)
 
         if self.precision == 'float16':
             print(' | Using faster float16 precision')
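
For context on the new dispatch: `_load_model()` now reads an optional `format` key from the model's configuration stanza, falling back to 'ckpt' when the key is absent. A minimal sketch of how the branch is selected, assuming OmegaConf-style entries carrying the fields `_load_ckpt_model()` reads; the diffusers stanza here is hypothetical, since its loader (below) is not yet implemented:

from omegaconf import OmegaConf

# Two illustrative stanzas. The first carries the fields the ckpt branch
# reads (config, weights, width, height); the second exists only to show
# the `format` key the dispatch branches on.
conf = OmegaConf.create('''
stable-diffusion-1.4:
  config: configs/stable-diffusion/v1-inference.yaml
  weights: models/ldm/stable-diffusion-v1/model.ckpt
  width: 512
  height: 512
some-diffusers-model:
  format: diffusers
''')

for name, mconfig in conf.items():
    print(name, '->', mconfig.get('format', 'ckpt'))  # same default as _load_model()
# stable-diffusion-1.4 -> ckpt
# some-diffusers-model -> diffusers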
@@ -243,18 +265,11 @@ def _load_model(self, model_name:str):
             if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 m._orig_padding_mode = m.padding_mode
 
-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
-        return model, width, height, model_hash
 
+
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError()  # return pipeline, width, height, model_hash
+
     def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
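The `_load_diffusers_model` stub pins down the return contract the dispatch expects: (pipeline, width, height, model_hash). Purely as a hedged sketch of where this could go — not part of this commit, which deliberately raises NotImplementedError — a loader built on Hugging Face's diffusers library might look roughly like this, with `repo_id` an assumed config key:

from diffusers import StableDiffusionPipeline

def _load_diffusers_model(self, mconfig):
    # Hypothetical sketch; `repo_id` (a hub id or local directory) is an
    # assumed key, and no hashing scheme for diffusers models exists yet.
    pipeline = StableDiffusionPipeline.from_pretrained(mconfig.repo_id)
    width = mconfig.get('width', 512)    # assumed fallback dimensions
    height = mconfig.get('height', 512)
    model_hash = str(mconfig.repo_id)    # placeholder until real hashing lands
    return pipeline, width, height, model_hash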
