diff --git a/language/gpt-j/quantization/autoscale/model_dict.py b/language/gpt-j/quantization/autoscale/model_dict.py
index 196a17269..e3e021682 100644
--- a/language/gpt-j/quantization/autoscale/model_dict.py
+++ b/language/gpt-j/quantization/autoscale/model_dict.py
@@ -7,7 +7,6 @@ GPTJForCausalLM_dict = {
     transformers.models.gptj.modeling_gptj.GPTJForCausalLM : transformers.models.gptj.modeling_gptj,
     furiosa_llm_models.gptj.huggingface.GPTJForCausalLM : furiosa_llm_models.gptj.huggingface,
-    furiosa_llm_models.gptj.paged_attention_concat.GPTJForCausalLM : furiosa_llm_models.gptj.paged_attention_concat,
     furiosa_llm_models.gptj.huggingface_rope.GPTJForCausalLM: furiosa_llm_models.gptj.huggingface_rope,
     furiosa_llm_models.gptj.paged_attention_concat_rope.GPTJForCausalLM: furiosa_llm_models.gptj.paged_attention_concat_rope,
     furiosa_llm_models.gptj.preallocated_concat_rope.GPTJForCausalLM: furiosa_llm_models.gptj.preallocated_concat_rope,
diff --git a/language/gpt-j/quantization/calibration_utils/paged_attention_utils.py b/language/gpt-j/quantization/calibration_utils/paged_attention_utils.py
index 78f98d170..0aa2a968b 100644
--- a/language/gpt-j/quantization/calibration_utils/paged_attention_utils.py
+++ b/language/gpt-j/quantization/calibration_utils/paged_attention_utils.py
@@ -31,10 +31,6 @@ def update_input_metadata(updated_attention_mask: List[List[int]], block_indices
         active_key_block_indices.append([])
         active_value_block_indices.append([])
 
-        # last_valid_key_block_idx = None
-        # last_valid_value_block_idx = None
-        # last_valid_token_idx = None
-
         for block in split_blocks:
             # x x 1 => then block is full
             # 1 x x => block is not full
@@ -62,17 +58,6 @@ def update_input_metadata(updated_attention_mask: List[List[int]], block_indices
                 active_key_block_indices[batch_idx].append(new_key_block_idx)
                 active_value_block_indices[batch_idx].append(new_value_block_idx)
 
-                # last_valid_key_block_idx = new_key_block_idx
-                # last_valid_value_block_idx = new_value_block_idx
-                # last_valid_token_idx = last_idx
-
-        # self.valid_block_meta.append(
-        #     (
-        #         (last_valid_key_block_idx, last_valid_token_idx),
-        #         (last_valid_value_block_idx, last_valid_token_idx),
-        #     )
-        # )
-
         new_key_locations.append(torch.unsqueeze(torch.cat(new_key_location), 0))
         new_value_locations.append(torch.unsqueeze(torch.cat(new_value_location), 0))
 
@@ -100,7 +85,7 @@ def make_calib_dataloader_for_paged_attention(calib_dataset_path, batch_size, bu
 
     #There could be a bug associated with multi-batch calibration in mcp at the moment.
     assert batch_size == 1
-    # batch_size = 2
+
     data_object = Dataset(calib_dataset_path, batch_size)
     data_list = []
     block_indices, block_size, head, head_size = total_block_space[0][0].shape
@@ -162,7 +147,9 @@ def make_calib_dataloader_for_paged_attention(calib_dataset_path, batch_size, bu
     return DataLoader(data_list, batch_size)
-
+
+
+