diff --git a/llava/mm_utils.py b/llava/mm_utils.py
index 2662b91db..23fdac956 100644
--- a/llava/mm_utils.py
+++ b/llava/mm_utils.py
@@ -77,23 +77,26 @@ class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords, tokenizer, input_ids):
         self.keywords = keywords
         self.keyword_ids = []
+        self.max_keyword_len = 0
         for keyword in keywords:
             cur_keyword_ids = tokenizer(keyword).input_ids
             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                 cur_keyword_ids = cur_keyword_ids[1:]
+            if len(cur_keyword_ids) > self.max_keyword_len:
+                self.max_keyword_len = len(cur_keyword_ids)
             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
         self.tokenizer = tokenizer
         self.start_len = input_ids.shape[1]
 
     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
-        offset = min(output_ids.shape[1] - self.start_len, 3)
+        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
         for keyword_id in self.keyword_ids:
-            if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
+            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                 return True
         outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
         for keyword in self.keywords:
             if keyword in outputs:
                 return True
-        return False
+        return False
\ No newline at end of file
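
Why the `.all()` change matters: comparing two multi-element tensors with `==` returns an element-wise boolean tensor, not a single bool, so the old `if` raised a RuntimeError for any keyword that tokenizes to more than one id. A minimal sketch of the failure mode and the fix, with made-up token ids for illustration:

    import torch

    # Hypothetical tail of the generated sequence and a multi-token stop keyword.
    tail = torch.tensor([13, 2])
    keyword_id = torch.tensor([13, 2])

    print(tail == keyword_id)          # tensor([True, True]) -- element-wise, not a bool
    print((tail == keyword_id).all())  # tensor(True) -- collapses to one truth value

    try:
        if tail == keyword_id:         # the pre-fix pattern
            pass
    except RuntimeError as e:
        print(e)  # "Boolean value of Tensor with more than one element is ambiguous"

The `max_keyword_len` change fixes a related blind spot in the string-based fallback: the hard-coded `offset = min(..., 3)` decoded only the last three generated tokens, so a keyword spanning more than three token ids could never be matched there. Tracking the longest keyword's token length (measured after stripping the BOS token, as in `__init__` above) sizes the decode window to the keywords actually being watched.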