Skip to content

Commit

Permalink
Merge pull request #697 from 9173860/patch-3
Browse files Browse the repository at this point in the history
Vectorized infer_beam_batch for improved performance
  • Loading branch information
zyddnys authored Aug 24, 2024
2 parents 94a411e + 796f65e commit ead6693
Showing 1 changed file with 150 additions and 1 deletion.
151 changes: 150 additions & 1 deletion manga_translator/ocr/model_48px.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args:
if self.use_gpu:
image_tensor = image_tensor.to(self.device)
with torch.no_grad():
ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)
ret = self.model.infer_beam_batch_tensor(image_tensor, widths, beams_k = 5, max_seq_length = 255)
for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret):
if prob < 0.2:
continue
Expand Down Expand Up @@ -497,16 +497,21 @@ def __init__(self, dictionary, max_len):
self.backbone = ConvNext_FeatureExtractor(48, 3, embd_dim)
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()

for i in range(4) :
encoder = nn.TransformerEncoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
encoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
encoder.forward = transformer_encoder_forward
self.encoders.append(encoder)
self.encoders.forward = self.encoder_forward

for i in range(5) :
decoder = nn.TransformerDecoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
decoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
decoder.multihead_attn = XposMultiheadAttention(embd_dim, nhead, encoder_decoder_attention = True)
self.decoders.append(decoder)
self.decoders.forward = self.decoder_forward

self.embd = nn.Embedding(self.dict_size, embd_dim)
self.pred1 = nn.Sequential(nn.Linear(embd_dim, embd_dim), nn.GELU(), nn.Dropout(0.15))
self.pred = nn.Linear(embd_dim, self.dict_size)
Expand All @@ -517,6 +522,37 @@ def __init__(self, dictionary, max_len):
self.color_pred_fg_ind = nn.Linear(64, 2)
self.color_pred_bg_ind = nn.Linear(64, 2)

def encoder_forward(self, memory, encoder_mask):
for layer in self.encoders :
memory = layer(layer, src = memory, src_key_padding_mask = encoder_mask)
return memory

def decoder_forward(
self,
embd: torch.Tensor,
cached_activations: torch.Tensor, # Shape [N, L, T, E] where L=num_layers, T=sequence length, E=embedding size
memory: torch.Tensor, # Shape [N, H, W, C] (Encoder memory output)
memory_mask: torch.BoolTensor,
step: int
):

layer: nn.TransformerDecoderLayer
tgt = embd # N, 1, E for the last token embedding

for l, layer in enumerate(self.decoders):
combined_activations = cached_activations[:, l, :step, :] # N, T, E
combined_activations = torch.cat([combined_activations, tgt], dim=1) # N, T+1, E
cached_activations[:, l, step, :] = tgt.squeeze(1)

# Update cache and perform self attention
tgt = tgt + layer.self_attn(layer.norm1(tgt), layer.norm1(combined_activations), layer.norm1(combined_activations), q_offset=step)[0]
tgt = tgt + layer.multihead_attn(layer.norm2(tgt), memory, memory, key_padding_mask=memory_mask, q_offset=step)[0]
tgt = tgt + layer._ff_block(layer.norm3(tgt))

cached_activations[:, l+1, step, :] = tgt.squeeze(1) # Append the new activations

return tgt.squeeze_(1), cached_activations

def forward(self,
img: torch.FloatTensor,
char_idx: torch.LongTensor,
Expand Down Expand Up @@ -621,6 +657,119 @@ def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_
result.append((cur_hypo.out_idx[1:], cur_hypo.prob(), fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))
return result

def infer_beam_batch_tensor(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
N, C, H, W = img.shape
assert H == 48 and C == 3

memory = self.backbone(img)
memory = einops.rearrange(memory, 'N C 1 W -> N W C')
valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
input_mask = torch.zeros(N, memory.size(1), dtype = torch.bool).to(img.device)

for i, l in enumerate(valid_feats_length):
input_mask[i, l:] = True
memory = self.encoders(memory, input_mask) # N, W, Dim

out_idx = torch.full((N, 1), start_tok, dtype=torch.long, device=img.device) # Shape [N, 1]
cached_activations = torch.zeros(N, len(self.decoders)+1, max_seq_length, 320, device=img.device) # [N, L, S, E]
log_probs = torch.zeros(N, 1, device=img.device) # Shape [N, 1] # N, E
idx_embedded = self.embd(out_idx[:, -1:])

decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, 0)
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # N, n_chars
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1) # N, k

out_idx = torch.cat([out_idx.unsqueeze(1).expand(-1, beams_k, -1), pred_chars_index.unsqueeze(-1)], dim=-1).reshape(-1, 2) # Shape [N * k, 2]
log_probs = pred_chars_values.view(-1, 1) # Shape [N * k, 1]
memory = memory.repeat_interleave(beams_k, dim=0)
input_mask = input_mask.repeat_interleave(beams_k, dim=0)
cached_activations = cached_activations.repeat_interleave(beams_k, dim=0)
batch_index = torch.arange(N).repeat_interleave(beams_k, dim=0).to(img.device)

finished_hypos = defaultdict(list)
N_remaining = N

for step in range(1, max_seq_length):
idx_embedded = self.embd(out_idx[:, -1:])
decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, step)
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # Shape [N * k, dict_size]
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim=1) # [N * k, k]

finished = out_idx[:, -1] == end_tok
pred_chars_values[finished] = 0
pred_chars_index[finished] = end_tok

# Extend hypotheses
new_out_idx = out_idx.unsqueeze(1).expand(-1, beams_k, -1) # Shape [N * k, k, seq_len]
new_out_idx = torch.cat([new_out_idx, pred_chars_index.unsqueeze(-1)], dim=-1) # Shape [N * k, k, seq_len + 1]
new_out_idx = new_out_idx.view(-1, step + 2) # Reshape to [N * k^2, seq_len + 1]
new_log_probs = log_probs.unsqueeze(1).expand(-1, beams_k, -1) + pred_chars_values.unsqueeze(-1) # Shape [N * k^2, 1]
new_log_probs = new_log_probs.view(-1, 1) # [N * k^2, 1]

# Sort and select top-k hypotheses per sample
new_out_idx = new_out_idx.view(N_remaining, -1, step + 2) # [N, k^2, seq_len + 1]
new_log_probs = new_log_probs.view(N_remaining, -1) # [N, k^2]
batch_topk_log_probs, batch_topk_indices = new_log_probs.topk(beams_k, dim=1) # [N, k]

# Gather the top-k hypotheses based on log probabilities
expanded_topk_indices = batch_topk_indices.unsqueeze(-1).expand(-1, -1, new_out_idx.shape[-1]) # Shape [N, k, seq_len + 1]
out_idx = torch.gather(new_out_idx, 1, expanded_topk_indices).reshape(-1, step + 2) # [N * k, seq_len + 1]
log_probs = batch_topk_log_probs.view(-1, 1) # Reshape to [N * k, 1]

# Check for finished sequences
finished = (out_idx[:, -1] == end_tok) # Check if the last token is the end token
finished = finished.view(N_remaining, beams_k) # Reshape to [N, k]
finished_counts = finished.sum(dim=1) # Count the number of finished hypotheses per sample
finished_batch_indices = (finished_counts >= max_finished_hypos).nonzero(as_tuple=False).squeeze()

if finished_batch_indices.numel() == 0:
continue

if finished_batch_indices.dim() == 0:
finished_batch_indices = finished_batch_indices.unsqueeze(0)

for idx in finished_batch_indices:
batch_log_probs = batch_topk_log_probs[idx]
best_beam_idx = batch_log_probs.argmax()
finished_hypos[batch_index[beams_k * idx].item()] = \
out_idx[idx * beams_k + best_beam_idx], \
torch.exp(batch_log_probs[best_beam_idx]).item(), \
cached_activations[idx * beams_k + best_beam_idx]

remaining_indexs = []
for i in range(N_remaining):
if i not in finished_batch_indices:
for j in range(beams_k):
remaining_indexs.append(i * beams_k + j)

if not remaining_indexs:
break

N_remaining = int(len(remaining_indexs) / beams_k)
out_idx = out_idx.index_select(0, torch.tensor(remaining_indexs, device=img.device))
log_probs = log_probs.index_select(0, torch.tensor(remaining_indexs, device=img.device))
memory = memory.index_select(0, torch.tensor(remaining_indexs, device=img.device))
cached_activations = cached_activations.index_select(0, torch.tensor(remaining_indexs, device=img.device))
input_mask = input_mask.index_select(0, torch.tensor(remaining_indexs, device=img.device))
batch_index = batch_index.index_select(0, torch.tensor(remaining_indexs, device=img.device))

# Ensure we have the correct number of finished hypotheses for each sample
assert len(finished_hypos) == N

# Final output processing and color predictions
result = []
for i in range(N):
final_idx, prob, decoded = finished_hypos[i]
color_feats = self.color_pred1(decoded[-1].unsqueeze(0))
fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = \
self.color_pred_fg(color_feats), \
self.color_pred_bg(color_feats), \
self.color_pred_fg_ind(color_feats), \
self.color_pred_bg_ind(color_feats)
result.append((final_idx[1:], prob, fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))

return result

import numpy as np

def convert_pl_model(filename: str) :
Expand Down

0 comments on commit ead6693

Please sign in to comment.