diff --git a/maga_transformer/async_decoder_engine/batch_query.py b/maga_transformer/async_decoder_engine/batch_query.py
index 7f29dd26..e10a9399 100644
--- a/maga_transformer/async_decoder_engine/batch_query.py
+++ b/maga_transformer/async_decoder_engine/batch_query.py
@@ -52,8 +52,8 @@ def __init__(self, gen_num_per_circle: int, nccl_op: Any) -> None:
         self.context_lengths_list: List[int] = []
         self.record_index_prob: Optional[torch.Tensor] = None
         self.lora_ids: List[int] = []
+        self.calculate_loss: List[int] = []
         self._ptuning_info = PtuningInfo()
-
         self.model_output = ModelOutput()
 
     def __str__(self):
@@ -79,6 +79,7 @@ def deepcopy(self) -> 'BatchQuery':
         new_batch_query.reuse_lengths_list = copy.deepcopy(self.reuse_lengths_list)
         new_batch_query.context_lengths_list = copy.deepcopy(self.context_lengths_list)
         new_batch_query.lora_ids = copy.deepcopy(self.lora_ids)
+        new_batch_query.calculate_loss = copy.deepcopy(self.calculate_loss)
         new_batch_query._ptuning_info = self._ptuning_info
         new_batch_query.model_output.update_length = self.model_output.update_length
         return new_batch_query
@@ -110,6 +111,7 @@ def tp_sync(self):
             output_token_ids = to_cuda(self.output_token_ids)
             cache_block_indice = to_cuda(self.cache_block_indice)
             lora_ids_tensor = to_cuda(torch.IntTensor(self.lora_ids))
+            calculate_loss_tensor = to_cuda(torch.IntTensor(self.calculate_loss))
         else:
             self.generate_batch_size = int(shape_hints[1])
             self.context_batch_size = int(shape_hints[2])
@@ -124,9 +126,11 @@ def tp_sync(self):
             reuse_lengths_tensor = torch.zeros((self.decoder_batch_size), dtype=torch.int32, device="cuda:0")
             context_lengths_tensor = torch.zeros((self.decoder_batch_size), dtype=torch.int32, device="cuda:0")
             lora_ids_tensor = torch.zeros((max(1, self.total_batch_size)), dtype=torch.int32, device="cuda:0")
+            calculate_loss_tensor = torch.zeros((self.context_batch_size), dtype=torch.int32, device="cuda:0")
         self.nccl_op_.broadcast_tp([
             cache_block_indice, output_token_ids, seq_lengths_tensor,
-            reuse_lengths_tensor, context_lengths_tensor, lora_ids_tensor
+            reuse_lengths_tensor, context_lengths_tensor, lora_ids_tensor,
+            calculate_loss_tensor
         ])
         if g_parallel_info.tp_rank > 0:
             self.cache_block_indice = to_cpu(cache_block_indice)
@@ -135,6 +139,7 @@ def tp_sync(self):
             self.reuse_lengths_list = to_cpu(reuse_lengths_tensor).numpy().tolist()
             self.context_lengths_list = to_cpu(context_lengths_tensor).numpy().tolist()
             self.lora_ids = to_cpu(lora_ids_tensor).numpy().tolist()
+            self.calculate_loss = to_cpu(calculate_loss_tensor).numpy().tolist()
 
     @property
     def max_context_length(self):
@@ -266,6 +271,7 @@ def generate_model_input(self):
         self.cache_block_indice = torch.IntTensor(cache_block_indice)
         self.merge_generate_config = GenerateConfig.merge_generate_config(self.generate_configs)
         self.lora_ids = lora_ids
+        self.calculate_loss = [c.calculate_loss for c in self.generate_configs[self.generate_batch_size:]]
         self.check()
 
     def update_all_errors(self, err: str):
diff --git a/maga_transformer/async_decoder_engine/normal_model_executor.py b/maga_transformer/async_decoder_engine/normal_model_executor.py
index 514252c5..25bcd12b 100644
--- a/maga_transformer/async_decoder_engine/normal_model_executor.py
+++ b/maga_transformer/async_decoder_engine/normal_model_executor.py
@@ -52,9 +52,9 @@ def _to_cuda_tensor(t: Optional[List[Any]], dtype: torch.dtype=torch.int32):
 
     def process(self, batch_query: BatchQuery) -> None:
         all_hidden_states = self._process(batch_query)
-        hidden_states = self._unpack_hidden_states(batch_query, all_hidden_states)
-        self._calculate_loss(batch_query, all_hidden_states)
+        hidden_states = self._select_last_hidden_states(batch_query, all_hidden_states)
         logits = self._post_transformer_nn(hidden_states)
+        self._calculate_loss(batch_query, all_hidden_states)
         if g_parallel_info.tp_size > 1 and g_parallel_info.tp_rank > 0:
             return
         with torch.cuda.nvtx.range('post_process'):
@@ -66,7 +66,7 @@ def create_config_json(self):
         }
         return config_json
 
-    def _unpack_hidden_states(self, batch_query: BatchQuery, hidden_states: torch.Tensor):
+    def _select_last_hidden_states(self, batch_query: BatchQuery, hidden_states: torch.Tensor):
         index_list = list(range(0, batch_query.generate_batch_size * batch_query.num_beams))
         offset = batch_query.generate_batch_size * batch_query.num_beams - 1
         for i in range(0, batch_query.context_batch_size):
@@ -74,6 +74,12 @@ def _unpack_hidden_states(self, batch_query: BatchQuery, hidden_states: torch.Te
             index_list.append(offset)
         return hidden_states.index_select(0, torch.tensor(index_list, device="cuda:0"))
 
+    def _select_context_hidden_states(self, batch_query: BatchQuery, hidden_states: torch.Tensor, idx):
+        offset = batch_query.generate_batch_size * batch_query.num_beams
+        for i in range(idx):
+            offset += batch_query.context_query_context_lengths_list[i]
+        return hidden_states[offset:offset + batch_query.context_query_context_lengths_list[idx], ...]
+
     def _process(self, batch_query: BatchQuery) -> torch.Tensor:
         with torch.cuda.nvtx.range('pre_process'):
             input_embeds, attention_mask, position_ids = self._pre_process(batch_query)
@@ -251,26 +257,23 @@ def _reconstruct_sampler(self, batch_query: BatchQuery) -> None:
         self.model_ops.sampler.config = new_config
         self.model_ops.sampler.setup(SamplerSetupParams(batch_query.total_batch_size, self.model_ops.config.special_tokens.eos_token_id, 0, None))
 
-    def _calculate_loss(self,
-            batch_query: BatchQuery, all_hidden_states: torch.Tensor):
-        for stream in batch_query.streams:
-            if stream.generate_config.calculate_loss and query.loss == None:
-                all_logits = self._post_transformer_nn(all_hidden_states)
-                break
-
-        start_idx = 0
-        for i, stream in enumerate(batch_query.streams):
-            if not stream.generate_config.calculate_loss:
+    def _calculate_loss(self, batch_query: BatchQuery, all_hidden_states: torch.Tensor):
+        for context_idx, calculate_loss in enumerate(batch_query.calculate_loss):
+            if not calculate_loss:
                 continue
-            if query.loss != None:
+            hidden_states = self._select_context_hidden_states(
+                batch_query, all_hidden_states, context_idx)
+            logits = self._post_transformer_nn(hidden_states)
+            if g_parallel_info.tp_size > 1 and g_parallel_info.tp_rank > 0:
                 continue
-            shift_labels = stream.complete_token_ids[0, 1:stream.context_length].type(torch.int64).contiguous()
-            shift_logits = all_logits[start_idx : start_idx + stream.context_length - 1, ].contiguous()
-            start_idx += stream.context_length
+            stream = batch_query.context_streams[context_idx]
+            shift_labels = stream.complete_token_ids[0, 1:stream.input_length].type(torch.int64)
+            shift_logits = logits[:stream.input_length - 1, ]
             loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
             loss = loss_fct(shift_logits.to("cuda:0"), shift_labels.to("cuda:0"))
+
             if stream.generate_config.calculate_loss == 1:
-                loss_mean = loss.sum(dim=0) / (stream.context_length - 1)
+                loss_mean = loss.sum(dim=0) / (stream.input_length - 1)
                 loss_mean = loss_mean.exp()
                 stream.set_loss(loss_mean)
             elif stream.generate_config.calculate_loss == 2:
diff --git a/maga_transformer/server/inference_worker.py b/maga_transformer/server/inference_worker.py
index 189980b4..9fad98ea 100644
--- a/maga_transformer/server/inference_worker.py
+++ b/maga_transformer/server/inference_worker.py
@@ -66,11 +66,12 @@ def _format_response(self, gen_responses: GenerateResponse,
             "finished": finished,
             "aux_info": aux_info.model_dump(mode='json'),
         }
-        if return_hidden_states:
+        # Check for None: some batches may not have been computed yet.
+        if return_hidden_states and hidden_states is not None:
            response["hidden_states"] = hidden_states.tolist()
-        if calculate_loss:
+        if calculate_loss and loss is not None:
            response['loss'] = loss.tolist()
-        if return_logits:
+        if return_logits and logits is not None:
            response['logits'] = logits.tolist()
        return response
 
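
Note on the new loss path (a reviewer sketch, not part of the patch): the batch's hidden states are laid out with the generate_batch_size * num_beams decode rows first, followed by each context query's prompt tokens concatenated in request order, which is why _select_context_hidden_states sums the preceding context lengths to locate a query's slice. _calculate_loss then applies a shifted next-token cross entropy to that slice: calculate_loss == 1 returns the exponentiated mean (a perplexity-style scalar), calculate_loss == 2 the per-token losses. Below is a minimal self-contained sketch with toy tensors; select_context_slice, shifted_prompt_loss, and all sizes are hypothetical stand-ins for the engine's internals:

import torch

def select_context_slice(hidden_states, generate_rows, context_lengths, idx):
    # Decode queries occupy rows [0, generate_rows); context queries' token
    # rows are concatenated after them in request order.
    offset = generate_rows + sum(context_lengths[:idx])
    return hidden_states[offset:offset + context_lengths[idx]]

def shifted_prompt_loss(logits, token_ids, mode):
    # Next-token objective: the logits at position t predict token t + 1,
    # so drop the last logit row and the first label token.
    shift_logits = logits[:-1]
    shift_labels = token_ids[1:].to(torch.int64)
    loss = torch.nn.CrossEntropyLoss(reduction="none")(shift_logits, shift_labels)
    if mode == 1:
        # calculate_loss == 1: exponentiated mean, i.e. prompt perplexity.
        return loss.mean().exp()
    # calculate_loss == 2: one loss value per predicted token.
    return loss

vocab, hidden_dim, generate_rows, context_lengths = 32, 16, 2, [5, 7]
hidden = torch.randn(generate_rows + sum(context_lengths), hidden_dim)
proj = torch.nn.Linear(hidden_dim, vocab)  # stand-in for _post_transformer_nn
tokens = torch.randint(0, vocab, (context_lengths[1],))
logits = proj(select_context_slice(hidden, generate_rows, context_lengths, 1))
print(shifted_prompt_loss(logits, tokens, mode=1))  # scalar perplexity
print(shifted_prompt_loss(logits, tokens, mode=2))  # 6 per-token losses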