Vectorized infer_beam_batch for improved performance #697

Merged · 1 commit · Aug 24, 2024
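This PR replaces the hypothesis-by-hypothesis Python beam search of the 48px OCR model with a tensorized variant, infer_beam_batch_tensor, that decodes all samples and beams as one batch: per-layer decoder activations are cached so each step processes only the newest token, beams are extended with batched topk/gather operations, and samples whose hypotheses have finished are dropped from the batch as decoding proceeds.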
151 changes: 150 additions & 1 deletion manga_translator/ocr/model_48px.py
@@ -99,7 +99,7 @@ async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args:
if self.use_gpu:
    image_tensor = image_tensor.to(self.device)
with torch.no_grad():
-   ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)
+   ret = self.model.infer_beam_batch_tensor(image_tensor, widths, beams_k = 5, max_seq_length = 255)
for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret):
    if prob < 0.2:
        continue
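The one-line change above is the only touch to the caller: the new method accepts the same arguments and returns the same per-line tuples, so the remainder of the diff is the tensorized implementation itself.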
@@ -497,16 +497,21 @@ def __init__(self, dictionary, max_len):
self.backbone = ConvNext_FeatureExtractor(48, 3, embd_dim)
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()

for i in range(4) :
    encoder = nn.TransformerEncoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
    encoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
    encoder.forward = transformer_encoder_forward
    self.encoders.append(encoder)
self.encoders.forward = self.encoder_forward

for i in range(5) :
    decoder = nn.TransformerDecoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
    decoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
    decoder.multihead_attn = XposMultiheadAttention(embd_dim, nhead, encoder_decoder_attention = True)
    self.decoders.append(decoder)
self.decoders.forward = self.decoder_forward

self.embd = nn.Embedding(self.dict_size, embd_dim)
self.pred1 = nn.Sequential(nn.Linear(embd_dim, embd_dim), nn.GELU(), nn.Dropout(0.15))
self.pred = nn.Linear(embd_dim, self.dict_size)
@@ -517,6 +522,37 @@ def __init__(self, dictionary, max_len):
self.color_pred_fg_ind = nn.Linear(64, 2)
self.color_pred_bg_ind = nn.Linear(64, 2)

def encoder_forward(self, memory, encoder_mask):
    for layer in self.encoders :
        memory = layer(layer, src = memory, src_key_padding_mask = encoder_mask)
    return memory
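# Note on `layer(layer, ...)`: each encoder's `forward` was replaced in
# __init__ by the unbound module-level function transformer_encoder_forward,
# so nn.Module.__call__ no longer supplies the module as `self`; it has to be
# passed explicitly as the first argument. A minimal sketch of the same
# pattern (toy code, not part of this PR):
#
#     def doubled_forward(self, x):           # plain function, never bound
#         return nn.Linear.forward(self, x) * 2
#
#     layer = nn.Linear(4, 4)
#     layer.forward = doubled_forward         # instance attribute shadows the bound method
#     y = layer(layer, torch.randn(2, 4))     # module passed explicitly as `self`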

def decoder_forward(
    self,
    embd: torch.Tensor,
    cached_activations: torch.Tensor, # Shape [N, L+1, T, E] where L = number of decoder layers, T = max sequence length, E = embedding size
    memory: torch.Tensor, # Shape [N, W, E] (encoder memory output)
    memory_mask: torch.BoolTensor,
    step: int
):

    layer: nn.TransformerDecoderLayer
    tgt = embd # N, 1, E -- embedding of the newest token only

    for l, layer in enumerate(self.decoders):
        combined_activations = cached_activations[:, l, :step, :] # N, T, E
        combined_activations = torch.cat([combined_activations, tgt], dim=1) # N, T+1, E
        cached_activations[:, l, step, :] = tgt.squeeze(1) # update the cache for layer l

        # Self attention over the cached prefix plus the new token, then cross attention and feed-forward for the new token only
        tgt = tgt + layer.self_attn(layer.norm1(tgt), layer.norm1(combined_activations), layer.norm1(combined_activations), q_offset=step)[0]
        tgt = tgt + layer.multihead_attn(layer.norm2(tgt), memory, memory, key_padding_mask=memory_mask, q_offset=step)[0]
        tgt = tgt + layer._ff_block(layer.norm3(tgt))

    cached_activations[:, l+1, step, :] = tgt.squeeze(1) # append the final-layer activations

    return tgt.squeeze_(1), cached_activations
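# cached_activations acts like a KV cache: the slice [:, l, :step, :] holds the
# inputs decoder layer l has already seen, so each step attends from the newest
# token only instead of re-running the whole prefix. Layout sketch (shapes as
# allocated in infer_beam_batch_tensor below, E = 320):
#
#     cache = torch.zeros(N, num_layers + 1, max_seq_length, E)
#     cache[:, l, t, :]            # input to decoder layer l at step t
#     cache[:, num_layers, t, :]   # final-layer output at step t, reused later for color prediction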

def forward(self,
    img: torch.FloatTensor,
    char_idx: torch.LongTensor,
@@ -621,6 +657,119 @@ def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_
        result.append((cur_hypo.out_idx[1:], cur_hypo.prob(), fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))
    return result

def infer_beam_batch_tensor(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
    N, C, H, W = img.shape
    assert H == 48 and C == 3

    memory = self.backbone(img)
    memory = einops.rearrange(memory, 'N C 1 W -> N W C')
    valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
    input_mask = torch.zeros(N, memory.size(1), dtype = torch.bool).to(img.device)

    for i, l in enumerate(valid_feats_length):
        input_mask[i, l:] = True
    memory = self.encoders(memory, input_mask) # N, W, Dim
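    # Feature width is the pixel width ceil-divided by the backbone's 4x
    # horizontal stride, plus 2 (apparently an offset included in
    # ConvNext_FeatureExtractor's output width); True entries in input_mask
    # mark padded positions the attention should ignore.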

    out_idx = torch.full((N, 1), start_tok, dtype=torch.long, device=img.device) # Shape [N, 1]
    cached_activations = torch.zeros(N, len(self.decoders)+1, max_seq_length, 320, device=img.device) # [N, num_layers+1, S, E]; 320 = embd_dim
    log_probs = torch.zeros(N, 1, device=img.device) # Shape [N, 1]
    idx_embedded = self.embd(out_idx[:, -1:]) # Shape [N, 1, E]

    decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, 0)
    pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # N, n_chars
    pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1) # N, k

    out_idx = torch.cat([out_idx.unsqueeze(1).expand(-1, beams_k, -1), pred_chars_index.unsqueeze(-1)], dim=-1).reshape(-1, 2) # Shape [N * k, 2]
    log_probs = pred_chars_values.view(-1, 1) # Shape [N * k, 1]
    memory = memory.repeat_interleave(beams_k, dim=0)
    input_mask = input_mask.repeat_interleave(beams_k, dim=0)
    cached_activations = cached_activations.repeat_interleave(beams_k, dim=0)
    batch_index = torch.arange(N).repeat_interleave(beams_k, dim=0).to(img.device)
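    # Beam layout after repeat_interleave is sample-major: with N=2, k=3 the
    # rows are [s0b0, s0b1, s0b2, s1b0, s1b1, s1b2], so sample i's beams occupy
    # rows i*k .. i*k+k-1 and batch_index maps each row back to its original
    # sample once finished samples start dropping out of the batch.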

    finished_hypos = defaultdict(list) # sample index -> (token ids, prob, cached activations)
    N_remaining = N

    for step in range(1, max_seq_length):
        idx_embedded = self.embd(out_idx[:, -1:])
        decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, step)
        pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # Shape [N_remaining * k, dict_size]
        pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim=1) # [N_remaining * k, k]

        # Freeze finished beams: a beam that already emitted end_tok proposes k
        # identical end_tok continuations at zero added log-prob, so it passes
        # through top-k selection without changing its score.
        finished = out_idx[:, -1] == end_tok
        pred_chars_values[finished] = 0
        pred_chars_index[finished] = end_tok

        # Extend hypotheses
        new_out_idx = out_idx.unsqueeze(1).expand(-1, beams_k, -1) # Shape [N_remaining * k, k, seq_len]
        new_out_idx = torch.cat([new_out_idx, pred_chars_index.unsqueeze(-1)], dim=-1) # Shape [N_remaining * k, k, seq_len + 1]
        new_out_idx = new_out_idx.view(-1, step + 2) # Reshape to [N_remaining * k^2, seq_len + 1]
        new_log_probs = log_probs.unsqueeze(1).expand(-1, beams_k, -1) + pred_chars_values.unsqueeze(-1) # Shape [N_remaining * k, k, 1]
        new_log_probs = new_log_probs.view(-1, 1) # [N_remaining * k^2, 1]

        # Sort and select top-k hypotheses per sample
        new_out_idx = new_out_idx.view(N_remaining, -1, step + 2) # [N_remaining, k^2, seq_len + 1]
        new_log_probs = new_log_probs.view(N_remaining, -1) # [N_remaining, k^2]
        batch_topk_log_probs, batch_topk_indices = new_log_probs.topk(beams_k, dim=1) # [N_remaining, k]

        # Gather the top-k hypotheses based on log probabilities
        expanded_topk_indices = batch_topk_indices.unsqueeze(-1).expand(-1, -1, new_out_idx.shape[-1]) # Shape [N_remaining, k, seq_len + 1]
        out_idx = torch.gather(new_out_idx, 1, expanded_topk_indices).reshape(-1, step + 2) # [N_remaining * k, seq_len + 1]
        log_probs = batch_topk_log_probs.view(-1, 1) # Reshape to [N_remaining * k, 1]
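        # The topk/gather pair above replaces the per-hypothesis Python loop of
        # infer_beam_batch: all k*k candidate extensions per sample are scored
        # and pruned back to k beams in a few batched tensor ops.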

        # Check for finished sequences
        finished = (out_idx[:, -1] == end_tok) # is the last token the end token?
        finished = finished.view(N_remaining, beams_k) # Reshape to [N_remaining, k]
        finished_counts = finished.sum(dim=1) # number of finished hypotheses per sample
        finished_batch_indices = (finished_counts >= max_finished_hypos).nonzero(as_tuple=False).squeeze()

        if finished_batch_indices.numel() == 0:
            continue

        # squeeze() yields a 0-dim tensor when exactly one sample finished; restore the batch dim
        if finished_batch_indices.dim() == 0:
            finished_batch_indices = finished_batch_indices.unsqueeze(0)

        # A sample retires once max_finished_hypos of its beams have emitted
        # end_tok: record its highest-scoring beam (finished or not) together
        # with its probability and cached activations for color prediction.
        for idx in finished_batch_indices:
            batch_log_probs = batch_topk_log_probs[idx]
            best_beam_idx = batch_log_probs.argmax()
            finished_hypos[batch_index[beams_k * idx].item()] = \
                out_idx[idx * beams_k + best_beam_idx], \
                torch.exp(batch_log_probs[best_beam_idx]).item(), \
                cached_activations[idx * beams_k + best_beam_idx]

        remaining_indexs = []
        for i in range(N_remaining):
            if i not in finished_batch_indices:
                for j in range(beams_k):
                    remaining_indexs.append(i * beams_k + j)

        if not remaining_indexs:
            break

        # Drop the rows of retired samples from every per-beam tensor
        N_remaining = len(remaining_indexs) // beams_k
        remaining = torch.tensor(remaining_indexs, device=img.device)
        out_idx = out_idx.index_select(0, remaining)
        log_probs = log_probs.index_select(0, remaining)
        memory = memory.index_select(0, remaining)
        cached_activations = cached_activations.index_select(0, remaining)
        input_mask = input_mask.index_select(0, remaining)
        batch_index = batch_index.index_select(0, remaining)

    # Every sample must have produced a finished hypothesis
    assert len(finished_hypos) == N

    # Final output processing and color predictions
    result = []
    for i in range(N):
        final_idx, prob, decoded = finished_hypos[i]
        color_feats = self.color_pred1(decoded[-1].unsqueeze(0))
        fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = \
            self.color_pred_fg(color_feats), \
            self.color_pred_bg(color_feats), \
            self.color_pred_fg_ind(color_feats), \
            self.color_pred_bg_ind(color_feats)
        result.append((final_idx[1:], prob, fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))

    return result
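For reference, a self-contained toy of the batched beam update that the step loop performs (illustrative only, not part of the PR; V stands in for dict_size, and the real method also gathers the token sequences and activation caches alongside the scores):

import torch

N, k, V = 2, 3, 11
log_probs = torch.zeros(N * k, 1)                    # running score per beam
step_scores = torch.randn(N * k, V).log_softmax(-1)  # this step's char log-probs
vals, idx = step_scores.topk(k, dim=1)               # best k continuations per beam
cand = (log_probs + vals).view(N, k * k)             # k*k candidate scores per sample
best_vals, best_flat = cand.topk(k, dim=1)           # keep the top-k per sample
parent_beam = best_flat // k                         # which beam each survivor extends
new_token = idx.view(N, k * k).gather(1, best_flat)  # which character extends it
log_probs = best_vals.view(N * k, 1)                 # scores fed into the next step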

import numpy as np

def convert_pl_model(filename: str) :