fix: calculate_loss
rtp-llm committed Mar 1, 2024
1 parent 309f0fc commit 688b31e
Showing 3 changed files with 33 additions and 23 deletions.
maga_transformer/async_decoder_engine/batch_query.py (10 changes: 8 additions & 2 deletions)
@@ -52,8 +52,8 @@ def __init__(self, gen_num_per_circle: int, nccl_op: Any) -> None:
         self.context_lengths_list: List[int] = []
         self.record_index_prob: Optional[torch.Tensor] = None
         self.lora_ids: List[int] = []
+        self.calculate_loss: List[int] = []
         self._ptuning_info = PtuningInfo()

         self.model_output = ModelOutput()

     def __str__(self):
@@ -79,6 +79,7 @@ def deepcopy(self) -> 'BatchQuery':
         new_batch_query.reuse_lengths_list = copy.deepcopy(self.reuse_lengths_list)
         new_batch_query.context_lengths_list = copy.deepcopy(self.context_lengths_list)
         new_batch_query.lora_ids = copy.deepcopy(self.lora_ids)
+        new_batch_query.calculate_loss = copy.deepcopy(self.calculate_loss)
         new_batch_query._ptuning_info = self._ptuning_info
         new_batch_query.model_output.update_length = self.model_output.update_length
         return new_batch_query
Expand Down Expand Up @@ -110,6 +111,7 @@ def tp_sync(self):
output_token_ids = to_cuda(self.output_token_ids)
cache_block_indice = to_cuda(self.cache_block_indice)
lora_ids_tensor = to_cuda(torch.IntTensor(self.lora_ids))
calculate_loss_tensor = to_cuda(torch.IntTensor(self.calculate_loss))
else:
self.generate_batch_size = int(shape_hints[1])
self.context_batch_size = int(shape_hints[2])
@@ -124,9 +126,11 @@ def tp_sync(self):
             reuse_lengths_tensor = torch.zeros((self.decoder_batch_size), dtype=torch.int32, device="cuda:0")
             context_lengths_tensor = torch.zeros((self.decoder_batch_size), dtype=torch.int32, device="cuda:0")
             lora_ids_tensor = torch.zeros((max(1, self.total_batch_size)), dtype=torch.int32, device="cuda:0")
+            calculate_loss_tensor = torch.zeros((self.context_batch_size), dtype=torch.int32, device="cuda:0")
         self.nccl_op_.broadcast_tp([
             cache_block_indice, output_token_ids, seq_lengths_tensor,
-            reuse_lengths_tensor, context_lengths_tensor, lora_ids_tensor
+            reuse_lengths_tensor, context_lengths_tensor, lora_ids_tensor,
+            calculate_loss_tensor
         ])
         if g_parallel_info.tp_rank > 0:
             self.cache_block_indice = to_cpu(cache_block_indice)
@@ -135,6 +139,7 @@ def tp_sync(self):
             self.reuse_lengths_list = to_cpu(reuse_lengths_tensor).numpy().tolist()
             self.context_lengths_list = to_cpu(context_lengths_tensor).numpy().tolist()
             self.lora_ids = to_cpu(lora_ids_tensor).numpy().tolist()
+            self.calculate_loss = to_cpu(calculate_loss_tensor).numpy().tolist()

     @property
     def max_context_length(self):
@@ -266,6 +271,7 @@ def generate_model_input(self):
         self.cache_block_indice = torch.IntTensor(cache_block_indice)
         self.merge_generate_config = GenerateConfig.merge_generate_config(self.generate_configs)
         self.lora_ids = lora_ids
+        self.calculate_loss = [c.calculate_loss for c in self.generate_configs[self.generate_batch_size:]]
         self.check()

     def update_all_errors(self, err: str):
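Note: the tp_sync changes follow the file's existing pattern for per-query metadata: rank 0 packs a Python list into a fixed-shape CUDA tensor, all ranks join one broadcast, and ranks above 0 unpack the tensor back into a list (the shape is agreed on beforehand via the shape_hints exchange). A minimal sketch of that pattern, using torch.distributed.broadcast as a stand-in for the repo's nccl_op_.broadcast_tp wrapper; the function name and standalone form are illustrative only:

import torch
import torch.distributed as dist

def sync_flags(calculate_loss, context_batch_size, rank):
    # Assumes an NCCL process group is already initialized.
    if rank == 0:
        # Rank 0 owns the real values and materializes them on the GPU.
        flags = torch.IntTensor(calculate_loss).to("cuda:0")
    else:
        # Other ranks allocate an empty buffer of the pre-agreed shape.
        flags = torch.zeros((context_batch_size,), dtype=torch.int32, device="cuda:0")
    dist.broadcast(flags, src=0)  # every rank now holds identical flags
    return flags.cpu().numpy().tolist()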
maga_transformer/async_decoder_engine/normal_model_executor.py (39 changes: 21 additions & 18 deletions)
@@ -52,9 +52,9 @@ def _to_cuda_tensor(t: Optional[List[Any]], dtype: torch.dtype=torch.int32):

     def process(self, batch_query: BatchQuery) -> None:
         all_hidden_states = self._process(batch_query)
-        hidden_states = self._unpack_hidden_states(batch_query, all_hidden_states)
+        self._calculate_loss(batch_query, all_hidden_states)
+        hidden_states = self._select_last_hidden_states(batch_query, all_hidden_states)
         logits = self._post_transformer_nn(hidden_states)
-        self._calculate_loss(batch_query, all_hidden_states)
         if g_parallel_info.tp_size > 1 and g_parallel_info.tp_rank > 0:
             return
         with torch.cuda.nvtx.range('post_process'):
@@ -66,14 +66,20 @@ def create_config_json(self):
         }
         return config_json

-    def _unpack_hidden_states(self, batch_query: BatchQuery, hidden_states: torch.Tensor):
+    def _select_last_hidden_states(self, batch_query: BatchQuery, hidden_states: torch.Tensor):
         index_list = list(range(0, batch_query.generate_batch_size * batch_query.num_beams))
         offset = batch_query.generate_batch_size * batch_query.num_beams - 1
         for i in range(0, batch_query.context_batch_size):
             offset = offset + batch_query.context_query_context_lengths_list[i]
             index_list.append(offset)
         return hidden_states.index_select(0, torch.tensor(index_list, device="cuda:0"))

+    def _select_context_hidden_states(self, batch_query: BatchQuery, hidden_states: torch.Tensor, idx):
+        offset = batch_query.generate_batch_size * batch_query.num_beams
+        for i in range(idx):
+            offset += batch_query.context_query_context_lengths_list[i]
+        return hidden_states[offset:offset + batch_query.context_query_context_lengths_list[idx],...]
+
     def _process(self, batch_query: BatchQuery) -> torch.Tensor:
         with torch.cuda.nvtx.range('pre_process'):
             input_embeds, attention_mask, position_ids = self._pre_process(batch_query)
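Both selectors index into the batch's packed hidden states, whose rows are laid out as: one row per running generate query (times num_beams), followed by each context query's prompt tokens back to back. A toy illustration of that layout with made-up sizes; only the offset arithmetic mirrors the code above:

import torch

generate_batch_size, num_beams, hidden = 2, 1, 8
context_lengths = [3, 5]   # hypothetical prompt lengths of two context queries
rows = generate_batch_size * num_beams + sum(context_lengths)
hidden_states = torch.randn(rows, hidden)

def select_context(idx):
    # Skip the generate rows, then the prompts of the earlier context queries.
    offset = generate_batch_size * num_beams + sum(context_lengths[:idx])
    return hidden_states[offset : offset + context_lengths[idx]]

assert select_context(1).shape == (5, hidden)  # all prompt rows of context query 1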
@@ -251,26 +257,23 @@ def _reconstruct_sampler(self, batch_query: BatchQuery) -> None:
         self.model_ops.sampler.config = new_config
         self.model_ops.sampler.setup(SamplerSetupParams(batch_query.total_batch_size, self.model_ops.config.special_tokens.eos_token_id, 0, None))

-    def _calculate_loss(self,
-                        batch_query: BatchQuery, all_hidden_states: torch.Tensor):
-        for stream in batch_query.streams:
-            if stream.generate_config.calculate_loss and query.loss == None:
-                all_logits = self._post_transformer_nn(all_hidden_states)
-                break
-
-        start_idx = 0
-        for i, stream in enumerate(batch_query.streams):
-            if not stream.generate_config.calculate_loss:
+    def _calculate_loss(self, batch_query: BatchQuery, all_hidden_states: torch.Tensor):
+        for context_idx, calculate_loss in enumerate(batch_query.calculate_loss):
+            if not calculate_loss:
                 continue
-            if query.loss != None:
+            hidden_states = self._select_context_hidden_states(
+                batch_query, all_hidden_states, context_idx)
+            logits = self._post_transformer_nn(hidden_states)
+            if g_parallel_info.tp_size > 1 and g_parallel_info.tp_rank > 0:
                 continue
-            shift_labels = stream.complete_token_ids[0, 1:stream.context_length].type(torch.int64).contiguous()
-            shift_logits = all_logits[start_idx : start_idx + stream.context_length - 1, ].contiguous()
-            start_idx += stream.context_length
+            stream = batch_query.context_streams[context_idx]
+            shift_labels = stream.complete_token_ids[0, 1:stream.input_length].type(torch.int64)
+            shift_logits = logits[:stream.input_length - 1, ]
             loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
             loss = loss_fct(shift_logits.to("cuda:0"), shift_labels.to("cuda:0"))

             if stream.generate_config.calculate_loss == 1:
-                loss_mean = loss.sum(dim=0) / (stream.context_length - 1)
+                loss_mean = loss.sum(dim=0) / (stream.input_length - 1)
                 loss_mean = loss_mean.exp()
                 stream.set_loss(loss_mean)
             elif stream.generate_config.calculate_loss == 2:
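For calculate_loss == 1, the loop above computes the prompt's perplexity: per-token cross entropy of each next-token prediction over the prompt, averaged, then exponentiated. A self-contained sketch of just that computation, with illustrative shapes:

import torch

def prompt_perplexity(logits, token_ids):
    # logits: [input_length, vocab]; token_ids: [input_length] prompt tokens.
    shift_logits = logits[: token_ids.shape[0] - 1]  # predictions for positions 1..N-1
    shift_labels = token_ids[1:].to(torch.int64)     # each target is the next token
    loss = torch.nn.CrossEntropyLoss(reduction="none")(shift_logits, shift_labels)
    return (loss.sum() / (token_ids.shape[0] - 1)).exp()

ppl = prompt_perplexity(torch.randn(6, 32000), torch.randint(0, 32000, (6,)))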
maga_transformer/server/inference_worker.py (7 changes: 4 additions & 3 deletions)
@@ -66,11 +66,12 @@ def _format_response(self, gen_responses: GenerateResponse,
             "finished": finished,
             "aux_info": aux_info.model_dump(mode='json'),
         }
-        if return_hidden_states:
+        # Check for None: some batches may not have been computed yet
+        if return_hidden_states and hidden_states is not None:
             response["hidden_states"] = hidden_states.tolist()
-        if calculate_loss:
+        if calculate_loss and loss is not None:
             response['loss'] = loss.tolist()
-        if return_logits:
+        if return_logits and logits is not None:
             response['logits'] = logits.tolist()

         return response
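With these guards, a response for a request that set calculate_loss=1 might look as follows once the loss tensor exists; before it is computed the key is simply absent. The field values and aux_info contents here are made up for illustration:

response = {
    "finished": True,
    "aux_info": {"cost_time": 0.12, "input_len": 6},  # illustrative fields only
    "loss": 2.7183,  # exp of the mean token cross entropy over the prompt
}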