From 796f65ec0196d66b3cacbb9eb702adf2ff1a8bc7 Mon Sep 17 00:00:00 2001
From: OrangeCube
Date: Wed, 21 Aug 2024 23:21:06 +0800
Subject: [PATCH] Vectorized infer_beam_batch for improved performance

---
 manga_translator/ocr/model_48px.py | 151 ++++++++++++++++++++++++++++-
 1 file changed, 150 insertions(+), 1 deletion(-)

diff --git a/manga_translator/ocr/model_48px.py b/manga_translator/ocr/model_48px.py
index 84afd8745..fcd45eb03 100644
--- a/manga_translator/ocr/model_48px.py
+++ b/manga_translator/ocr/model_48px.py
@@ -99,7 +99,7 @@ async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args:
         if self.use_gpu:
             image_tensor = image_tensor.to(self.device)
         with torch.no_grad():
-            ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)
+            ret = self.model.infer_beam_batch_tensor(image_tensor, widths, beams_k = 5, max_seq_length = 255)
         for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret):
             if prob < 0.2:
                 continue
@@ -497,16 +497,21 @@ def __init__(self, dictionary, max_len):
         self.backbone = ConvNext_FeatureExtractor(48, 3, embd_dim)
         self.encoders = nn.ModuleList()
         self.decoders = nn.ModuleList()
+
         for i in range(4) :
             encoder = nn.TransformerEncoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
             encoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
             encoder.forward = transformer_encoder_forward
             self.encoders.append(encoder)
+        self.encoders.forward = self.encoder_forward
+
         for i in range(5) :
             decoder = nn.TransformerDecoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
             decoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
             decoder.multihead_attn = XposMultiheadAttention(embd_dim, nhead, encoder_decoder_attention = True)
             self.decoders.append(decoder)
+        self.decoders.forward = self.decoder_forward
+
         self.embd = nn.Embedding(self.dict_size, embd_dim)
         self.pred1 = nn.Sequential(nn.Linear(embd_dim, embd_dim), nn.GELU(), nn.Dropout(0.15))
         self.pred = nn.Linear(embd_dim, self.dict_size)
@@ -517,6 +522,37 @@ def __init__(self, dictionary, max_len):
         self.color_pred_fg_ind = nn.Linear(64, 2)
         self.color_pred_bg_ind = nn.Linear(64, 2)
 
+    def encoder_forward(self, memory, encoder_mask):
+        # `forward` on each layer was replaced with the unbound function
+        # `transformer_encoder_forward`, so the layer is passed in explicitly as `self`
+        for layer in self.encoders :
+            memory = layer(layer, src = memory, src_key_padding_mask = encoder_mask)
+        return memory
+
+    def decoder_forward(
+        self,
+        embd: torch.Tensor, # Shape [N, 1, E], embedding of the newest token
+        cached_activations: torch.Tensor, # Shape [N, L, S, E] where L=num_decoder_layers+1, S=max sequence length, E=embedding size
+        memory: torch.Tensor, # Shape [N, W, E] (encoder memory output)
+        memory_mask: torch.BoolTensor,
+        step: int
+    ):
+
+        layer: nn.TransformerDecoderLayer
+        tgt = embd # N, 1, E for the last token embedding
+
+        for l, layer in enumerate(self.decoders):
+            combined_activations = cached_activations[:, l, :step, :] # N, T, E
+            combined_activations = torch.cat([combined_activations, tgt], dim=1) # N, T+1, E
+            cached_activations[:, l, step, :] = tgt.squeeze(1)
+
+            # Self-attend over the cached prefix plus the current token, then cross-attend to the encoder memory
+            tgt = tgt + layer.self_attn(layer.norm1(tgt), layer.norm1(combined_activations), layer.norm1(combined_activations), q_offset=step)[0]
+            tgt = tgt + layer.multihead_attn(layer.norm2(tgt), memory, memory, key_padding_mask=memory_mask, q_offset=step)[0]
+            tgt = tgt + layer._ff_block(layer.norm3(tgt))
+
+            cached_activations[:, l+1, step, :] = tgt.squeeze(1) # Cache this layer's output for later steps
+
+        return tgt.squeeze_(1), cached_activations
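+
+    # A minimal usage sketch for the cached decoder, assuming plain greedy
+    # decoding (infer_beam_batch_tensor below is the beamed version). Here
+    # `start_ids` (an [N, 1] tensor of start_tok), `max_len` and `dev` are
+    # hypothetical names; `memory`/`memory_mask` are the encoder outputs:
+    #
+    #   cache = torch.zeros(N, len(self.decoders) + 1, max_len, self.embd.embedding_dim, device=dev)
+    #   tok = self.embd(start_ids)                       # [N, 1, E]
+    #   for t in range(max_len):
+    #       out, cache = self.decoders(tok, cache, memory, memory_mask, t)
+    #       nxt = self.pred(self.pred1(out)).argmax(-1)  # [N] greedy token pick
+    #       tok = self.embd(nxt).unsqueeze(1)            # [N, 1, E]
+    #
+    # Step t attends over cache[:, l, :t] plus the current token only, so each
+    # step costs O(t) attention instead of re-running the whole prefix.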
+
     def forward(self,
             img: torch.FloatTensor,
             char_idx: torch.LongTensor,
@@ -621,6 +657,119 @@ def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_
             result.append((cur_hypo.out_idx[1:], cur_hypo.prob(), fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))
         return result
 
+    def infer_beam_batch_tensor(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
+        N, C, H, W = img.shape
+        assert H == 48 and C == 3
+
+        memory = self.backbone(img)
+        memory = einops.rearrange(memory, 'N C 1 W -> N W C')
+        valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
+        input_mask = torch.zeros(N, memory.size(1), dtype = torch.bool).to(img.device)
+
+        for i, l in enumerate(valid_feats_length):
+            input_mask[i, l:] = True
+        memory = self.encoders(memory, input_mask) # N, W, Dim
+
+        # First decode step: run the start token once per sample
+        out_idx = torch.full((N, 1), start_tok, dtype=torch.long, device=img.device) # Shape [N, 1]
+        cached_activations = torch.zeros(N, len(self.decoders)+1, max_seq_length, self.embd.embedding_dim, device=img.device) # [N, L, S, E]
+        idx_embedded = self.embd(out_idx[:, -1:]) # N, 1, E
+
+        decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, 0)
+        pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # N, n_chars
+        pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1) # N, k
+
+        # Expand every sample into its top-k starting beams; each beam lives as a batch row
+        out_idx = torch.cat([out_idx.unsqueeze(1).expand(-1, beams_k, -1), pred_chars_index.unsqueeze(-1)], dim=-1).reshape(-1, 2) # Shape [N * k, 2]
+        log_probs = pred_chars_values.view(-1, 1) # Shape [N * k, 1]
+        memory = memory.repeat_interleave(beams_k, dim=0)
+        input_mask = input_mask.repeat_interleave(beams_k, dim=0)
+        cached_activations = cached_activations.repeat_interleave(beams_k, dim=0)
+        batch_index = torch.arange(N).repeat_interleave(beams_k, dim=0).to(img.device)
+
+        finished_hypos = {}
+        N_remaining = N
+
+        for step in range(1, max_seq_length):
+            idx_embedded = self.embd(out_idx[:, -1:])
+            decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, step)
+            pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # Shape [N * k, dict_size]
+            pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim=1) # [N * k, k]
+
+            # Finished beams may only extend with end_tok, at zero additional cost
+            finished = out_idx[:, -1] == end_tok
+            pred_chars_values[finished] = 0
+            pred_chars_index[finished] = end_tok
+
+            # Extend hypotheses
+            new_out_idx = out_idx.unsqueeze(1).expand(-1, beams_k, -1) # Shape [N * k, k, seq_len]
+            new_out_idx = torch.cat([new_out_idx, pred_chars_index.unsqueeze(-1)], dim=-1) # Shape [N * k, k, seq_len + 1]
+            new_out_idx = new_out_idx.view(-1, step + 2) # Reshape to [N * k^2, seq_len + 1]
+            new_log_probs = log_probs.unsqueeze(1).expand(-1, beams_k, -1) + pred_chars_values.unsqueeze(-1) # Shape [N * k, k, 1]
+            new_log_probs = new_log_probs.view(-1, 1) # [N * k^2, 1]
+
+            # Sort and select top-k hypotheses per sample
+            new_out_idx = new_out_idx.view(N_remaining, -1, step + 2) # [N, k^2, seq_len + 1]
+            new_log_probs = new_log_probs.view(N_remaining, -1) # [N, k^2]
+            batch_topk_log_probs, batch_topk_indices = new_log_probs.topk(beams_k, dim=1) # [N, k]
+
+            # Gather the top-k hypotheses based on log probabilities
+            expanded_topk_indices = batch_topk_indices.unsqueeze(-1).expand(-1, -1, new_out_idx.shape[-1]) # Shape [N, k, seq_len + 1]
+            out_idx = torch.gather(new_out_idx, 1, expanded_topk_indices).reshape(-1, step + 2) # [N * k, seq_len + 1]
+            log_probs = batch_topk_log_probs.view(-1, 1) # Reshape to [N * k, 1]
+
+            # Each surviving beam must also inherit its parent beam's activation cache:
+            # candidate q within a sample descends from local beam q // k. memory,
+            # input_mask and batch_index are identical across a sample's beams, so
+            # only the cache rows need reordering.
+            parent_beam = batch_topk_indices // beams_k # [N, k]
+            flat_parent = (torch.arange(N_remaining, device=img.device).unsqueeze(1) * beams_k + parent_beam).view(-1) # [N * k]
+            cached_activations = cached_activations.index_select(0, flat_parent)
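+
+            # Shape walk-through (illustrative): with N_remaining = 2 and beams_k = 3,
+            # the six beam rows [6, T] are each extended 3 ways to [6, 3, T+1],
+            # flattened to [18, T+1], regrouped per sample as [2, 9, T+1]; topk
+            # then keeps the best 3 of the 9 candidates per sample, restoring the
+            # [6, T+1] row layout so every surviving beam remains a plain batch row.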
+
+            # Check for finished sequences
+            finished = (out_idx[:, -1] == end_tok) # Check if the last token is the end token
+            finished = finished.view(N_remaining, beams_k) # Reshape to [N, k]
+            finished_counts = finished.sum(dim=1) # Count the number of finished hypotheses per sample
+            finished_batch_indices = (finished_counts >= max_finished_hypos).nonzero(as_tuple=False).squeeze()
+
+            if finished_batch_indices.numel() == 0:
+                continue
+
+            if finished_batch_indices.dim() == 0:
+                finished_batch_indices = finished_batch_indices.unsqueeze(0)
+
+            # Retire samples with enough finished beams, recording their best hypothesis
+            for idx in finished_batch_indices:
+                batch_log_probs = batch_topk_log_probs[idx]
+                best_beam_idx = batch_log_probs.argmax()
+                finished_hypos[batch_index[beams_k * idx].item()] = \
+                    out_idx[idx * beams_k + best_beam_idx], \
+                    torch.exp(batch_log_probs[best_beam_idx]).item(), \
+                    cached_activations[idx * beams_k + best_beam_idx]
+
+            # Keep only the beams of samples that are still decoding
+            remaining_indices = []
+            for i in range(N_remaining):
+                if i not in finished_batch_indices:
+                    for j in range(beams_k):
+                        remaining_indices.append(i * beams_k + j)
+
+            if not remaining_indices:
+                break
+
+            N_remaining = len(remaining_indices) // beams_k
+            remaining = torch.tensor(remaining_indices, device=img.device)
+            out_idx = out_idx.index_select(0, remaining)
+            log_probs = log_probs.index_select(0, remaining)
+            memory = memory.index_select(0, remaining)
+            cached_activations = cached_activations.index_select(0, remaining)
+            input_mask = input_mask.index_select(0, remaining)
+            batch_index = batch_index.index_select(0, remaining)
+
+        # Ensure we have the correct number of finished hypotheses for each sample;
+        # every sample is expected to finish within max_seq_length steps
+        assert len(finished_hypos) == N
+
+        # Final output processing and color predictions
+        result = []
+        for i in range(N):
+            final_idx, prob, decoded = finished_hypos[i]
+            color_feats = self.color_pred1(decoded[-1].unsqueeze(0))
+            fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = \
+                self.color_pred_fg(color_feats), \
+                self.color_pred_bg(color_feats), \
+                self.color_pred_fg_ind(color_feats), \
+                self.color_pred_bg_ind(color_feats)
+            result.append((final_idx[1:], prob, fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))
+
+        return result
+
 import numpy as np
 
 def convert_pl_model(filename: str) :