support use_cpu_to_save_cuda_mem_for_catcher for vlm quantization (#224)
helloyongyang authored Nov 23, 2024
1 parent d790837 · commit 3d2155e
Showing 2 changed files with 19 additions and 18 deletions.
llmc/compression/quantization/base_blockwise_quantization.py — 14 changes: 6 additions & 8 deletions
@@ -366,14 +366,11 @@ def block_forward(self, block, input_data=None):
 
         for i in range(len(input_data)):
             input_data[i] = input_data[i].to(device=next(block.parameters()).device)
-            keys_to_device = ['attention_mask', 'cross_attention_mask', 'cross_attention_states']
-            for key in keys_to_device:
-                if (
-                    key in self.input['kwargs'][i]
-                    and self.input['kwargs'][i][key] is not None
-                ):
-                    self.input['kwargs'][i][key] = \
-                        self.input['kwargs'][i][key].to(device=next(block.parameters()).device)
+            for k in self.input['kwargs'][i]:
+                if torch.is_tensor(self.input['kwargs'][i][k]):
+                    self.input['kwargs'][i][k] = self.input['kwargs'][i][k].to(device=next(block.parameters()).device)  # noqa
+                if isinstance(self.input['kwargs'][i][k], tuple):
+                    self.input['kwargs'][i][k] = tuple(tmp.to(device=next(block.parameters()).device) for tmp in self.input['kwargs'][i][k])  # noqa
             with torch.no_grad():
                 out = block(input_data[i], **self.input['kwargs'][i])
                 if isinstance(out, tuple):
@@ -452,6 +449,7 @@ def run(self, block, input_feat, handles):
                 )
         self.set_non_linear_mode('fake_quant', block, False)
         self.input['data'] = self.block_forward(block)
+        torch.cuda.empty_cache()
 
     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
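Note on this file: block_forward previously moved only a hardcoded list of kwargs (attention_mask, cross_attention_mask, cross_attention_states) to the block's device, which breaks for VLM blocks that receive other tensor arguments. The new loop moves every tensor-valued kwarg, and every tuple of tensors, to wherever the block's parameters live; run() additionally calls torch.cuda.empty_cache() after the forward so cached activations are released between blocks. Below is a minimal standalone sketch of the same device-sweep pattern — the helper name and the non-tensor guard inside tuples are illustrative additions, not from the repo:

    import torch

    def move_kwargs_to_device(kwargs, device):
        """Hypothetical helper mirroring the loop added in block_forward."""
        for k, v in kwargs.items():
            if torch.is_tensor(v):
                # Plain tensors (e.g. attention_mask) follow the block's device.
                kwargs[k] = v.to(device=device)
            elif isinstance(v, tuple):
                # Tuples (e.g. paired position embeddings) are rebuilt
                # element-wise; skipping non-tensor elements is an extra
                # safety not present in the committed code.
                kwargs[k] = tuple(
                    t.to(device=device) if torch.is_tensor(t) else t for t in v
                )
        return kwargs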
llmc/models/base_model.py — 23 changes: 13 additions & 10 deletions
@@ -25,6 +25,7 @@ def __init__(self, config, device_map=None, use_cache=False):
         self.model_type = self.config.model.type
         self.model_path = self.config.model.path
         self.tokenizer_mode = self.config.model.get('tokenizer_mode', 'fast')
+        self.use_cpu_to_save_cuda_mem_for_catcher = self.config.model.get('use_cpu_to_save_cuda_mem_for_catcher', False)  # noqa
         torch_dtype = self.config.model.torch_dtype
         self.torch_dtype = torch_dtype if torch_dtype == 'auto' else eval(torch_dtype)
         self.device_map = device_map
@@ -167,11 +168,12 @@ def collect_first_block_input(self, calib_data, padding_mask=None,
         self.find_blocks(modality)
         Catcher = self.get_catcher(first_block_input)
 
-        self.move_embed_to_device('cuda')
-        if data_type == 'img_txt':
-            self.vision_model = self.vision_model.to('cuda')
-            self.projector = self.projector.to('cuda')
-        self.blocks[0] = self.blocks[0].cuda()
+        if not self.use_cpu_to_save_cuda_mem_for_catcher:
+            self.move_embed_to_device('cuda')
+            if data_type == 'img_txt':
+                self.vision_model = self.vision_model.to('cuda')
+                self.projector = self.projector.to('cuda')
+            self.blocks[0] = self.blocks[0].cuda()
         self.blocks[0] = Catcher(self.blocks[0])
 
         for data in calib_data:
@@ -203,12 +205,13 @@
                         value=1
                     )
             self.padding_mask = padding_mask
-        if data_type == 'img_txt':
-            self.vision_model = self.vision_model.cpu()
-            self.projector = self.projector.cpu()
+        if not self.use_cpu_to_save_cuda_mem_for_catcher:
+            if data_type == 'img_txt':
+                self.vision_model = self.vision_model.cpu()
+                self.projector = self.projector.cpu()
+            self.blocks[0] = self.blocks[0].cpu()
+            self.move_embed_to_device('cpu')
         self.blocks[0] = self.blocks[0].module
-        self.blocks[0] = self.blocks[0].cpu()
-        self.move_embed_to_device('cpu')
 
     def get_one_pad_setting(self, padding_side, length):
         if padding_side == 'left':
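Note on this file: use_cpu_to_save_cuda_mem_for_catcher defaults to False, so existing configs behave as before. Setting it to true under the config's model section (the same section that supplies tokenizer_mode and torch_dtype) leaves the embeddings, vision tower, projector, and first block on CPU while the Catcher wrapper intercepts the first block's inputs during calibration, trading speed for CUDA memory. For context, here is a sketch of the standard catcher trick the flag guards — the class layout is inferred from Catcher = self.get_catcher(first_block_input), the single-argument Catcher(self.blocks[0]) call, and the later .module unwrap, so details may differ from the repo's get_catcher:

    import torch.nn as nn

    class CatcherExit(Exception):
        """Raised to abort the forward pass once inputs are captured (assumed name)."""

    def get_catcher(first_block_input):
        # Returns a Catcher class with first_block_input bound in the closure,
        # matching the single-argument call Catcher(self.blocks[0]) in the diff.
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module  # unwrapped later via `.module`

            def forward(self, hidden_states, **kwargs):
                # Record what reaches the first block, then stop the model so
                # no later blocks execute during calibration.
                first_block_input['data'].append(hidden_states)
                first_block_input['kwargs'].append(kwargs)
                raise CatcherExit

        return Catcher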
