Trying to fix the "substring not found" error #116

Open · wants to merge 7 commits into base: master
Changes from all commits
108 changes: 81 additions & 27 deletions pipelines.py
@@ -2,7 +2,15 @@
import logging
from typing import Optional, Dict, Union

# nltk == Natural Language Toolkit
from nltk import sent_tokenize
# sent_tokenize: returns a sentence-tokenized copy of the text, using NLTK's recommended
# sentence tokenizer (currently PunktSentenceTokenizer for the specified language).
#
# Signature: nltk.tokenize.sent_tokenize(text, language='english')
#   - text: the text to split into sentences
#   - language: the model name in the Punkt corpus

import torch
from transformers import (
@@ -33,62 +41,80 @@ def __init__(

self.qg_format = qg_format

# Set the device to GPU if available and use_cuda is True, otherwise use CPU
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self.model.to(self.device)

# Move the answer extraction model to the same device, if it's different from the main model
if self.ans_model is not self.model:
self.ans_model.to(self.device)

# Ensure the model is of type T5 or BART, as required for this pipeline
assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]

# Set model type for conditional operations later
if "T5ForConditionalGeneration" in self.model.__class__.__name__:
self.model_type = "t5"
else:
self.model_type = "bart"

def __call__(self, inputs: str):
# Clean up the input text by removing extra whitespace
inputs = " ".join(inputs.split())
# Extract sentences and answers from the input text
sents, answers = self._extract_answers(inputs)
# Flatten the list of answers for easy iteration
flat_answers = list(itertools.chain(*answers))

if len(flat_answers) == 0:
return []

# Prepare inputs for question generation based on the specified format
if self.qg_format == "prepend":
qg_examples = self._prepare_inputs_for_qg_from_answers_prepend(inputs, answers)
else:
qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)

# Generate questions for each example
qg_inputs = [example['source_text'] for example in qg_examples]
questions = self._generate_questions(qg_inputs)
# Combine the answers and generated questions into the output format
output = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, questions)]
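# Illustrative output (example values are assumed, not taken from this code):
#   [{'answer': 'Paris', 'question': 'What is the capital of France?'}, ...]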
return output

def _generate_questions(self, inputs):
# Tokenize the inputs for the model
inputs = self._tokenize(inputs, padding=True, truncation=True)

# Generate questions using the model
outs = self.model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
max_length=32,
num_beams=4, # Use beam search for better question generation
)

# Decode the generated output to get the questions
questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
return questions

def _extract_answers(self, context):
# Prepare inputs for answer extraction by splitting the context into sentences
sents, inputs = self._prepare_inputs_for_ans_extraction(context)
# Tokenize the inputs
inputs = self._tokenize(inputs, padding=True, truncation=True)

# Generate answers using the answer extraction model
outs = self.ans_model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
max_length=32,
)

# Decode the generated output to get answers and split by the separator token
dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
answers = [item.split('<sep>') for item in dec]
# Remove the last element (which is empty) from each list of answers
answers = [i[:-1] for i in answers]
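# e.g. a decoded string like "ans1<sep> ans2<sep>" splits into ['ans1', ' ans2', ''] and
# the trailing empty element is dropped (illustrative values)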

return sents, answers
@@ -100,6 +126,7 @@ def _tokenize(self,
add_special_tokens=True,
max_length=512
):
# Tokenize the inputs using the tokenizer, with options for padding and truncation
inputs = self.tokenizer.batch_encode_plus(
inputs,
max_length=max_length,
@@ -112,17 +139,20 @@ def _tokenize(self,
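# Returns a dict-like object with 'input_ids' and 'attention_mask' tensors (inferred from how the result is indexed elsewhere in this file)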
return inputs

def _prepare_inputs_for_ans_extraction(self, text):
# Split the input text into sentences
sents = sent_tokenize(text)

inputs = []
for i in range(len(sents)):
source_text = "extract answers:"
for j, sent in enumerate(sents):
# Highlight the current sentence
if i == j:
sent = "<hl> %s <hl>" % sent
source_text = "%s %s" % (source_text, sent)
source_text = source_text.strip()

# Add end-of-sequence token if the model type is T5
if self.model_type == "t5":
source_text = source_text + " </s>"
inputs.append(source_text)
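# Illustrative source text for the second of two sentences (T5 format; example text assumed):
#   "extract answers: First sentence. <hl> Second sentence. <hl> </s>"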
@@ -139,11 +169,27 @@ def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):

answer_text = answer_text.strip()

ans_start_idx = sent.index(answer_text)
# Debugging: Print the sentence and answer to understand why the index might fail
logger.debug(f"Processing sentence: {sent}")
logger.debug(f"Processing answer: {answer_text}")

# start of addition
if '<pad>' in answer_text:
answer_text = answer_text.replace('<pad>', '').strip()
logger.debug(f"Answer without <pad>: '{answer_text}'")
# end of addition

# Find the start index of the answer in the sentence
ans_start_idx = sent.find(answer_text)
if ans_start_idx == -1:
logger.warning(f"Answer '{answer_text}' not found in sentence '{sent}'")
continue

# Insert highlight tokens around the answer text
sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text): ]}"
sents_copy[i] = sent
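# e.g. "Paris is the capital." with answer "Paris" becomes
# "<hl> Paris <hl> is the capital." (modulo surrounding spaces; illustrative values)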

# Combine the sentences and prepare the source text for question generation
source_text = " ".join(sents_copy)
source_text = f"generate question: {source_text}"
if self.model_type == "t5":
@@ -154,9 +200,11 @@ def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
return inputs

def _prepare_inputs_for_qg_from_answers_prepend(self, context, answers):
# Flatten the list of answers
flat_answers = list(itertools.chain(*answers))
examples = []
for answer in flat_answers:
# Prepare the source text by prepending the answer to the context
source_text = f"answer: {answer} context: {context}"
if self.model_type == "t5":
source_text = source_text + " </s>"
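# e.g. "answer: Paris context: Paris is the capital of France. </s>" (illustrative, T5 format)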
@@ -171,28 +219,32 @@ def __init__(self, **kwargs):

def __call__(self, inputs: Union[Dict, str]):
if type(inputs) is str:
# If input is a string, perform question generation
return super().__call__(inputs)
else:
# If input is a dictionary, perform question answering
return self._extract_answer(inputs["question"], inputs["context"])

def _prepare_inputs_for_qa(self, question, context):
# Prepare the input text for the question answering task
source_text = f"question: {question} context: {context}"
if self.model_type == "t5":
source_text = source_text + " </s>"
return source_text

def _extract_answer(self, question, context):
# Prepare inputs for extracting answers
source_text = self._prepare_inputs_for_qa(question, context)
inputs = self._tokenize([source_text], padding=False)

# Generate the answer using the model
outs = self.model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
max_length=16,
)

# Decode the generated output to get the answer
answer = self.tokenizer.decode(outs[0], skip_special_tokens=True)
return answer

@@ -208,16 +260,19 @@ def __init__(
self.model = model
self.tokenizer = tokenizer

# Set the device to GPU if available and use_cuda is True, otherwise use CPU
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self.model.to(self.device)

# Ensure the model is of type T5 or BART
assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]

if "T5ForConditionalGeneration" in self.model.__class__.__name__:
self.model_type = "t5"
else:
self.model_type = "bart"

# Set default arguments for generation
self.default_generate_kwargs = {
"max_length": 256,
"num_beams": 4,
@@ -227,35 +282,30 @@ def __init__(
}

def __call__(self, context: str, **generate_kwargs):
# Prepare inputs for end-to-end question generation
inputs = self._prepare_inputs_for_e2e_qg(context)

# TODO: when overriding default_generate_kwargs all other arguments need to be passed
# find a better way to do this
# Override default arguments if any are provided
if not generate_kwargs:
generate_kwargs = self.default_generate_kwargs

input_length = inputs["input_ids"].shape[-1]

# max_length = generate_kwargs.get("max_length", 256)
# if input_length < max_length:
# logger.warning(
# "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
# max_length, input_length
# )
# )

# Generate questions using the model
outs = self.model.generate(
input_ids=inputs['input_ids'].to(self.device),
attention_mask=inputs['attention_mask'].to(self.device),
**generate_kwargs
)

# Decode the generated output to get the questions
prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
questions = prediction.split("<sep>")
questions = [question.strip() for question in questions[:-1]]
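# Illustrative result: a prediction like "Who wrote X?<sep> When was X written?<sep>"
# yields ['Who wrote X?', 'When was X written?'] (example questions are assumed)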
return questions

def _prepare_inputs_for_e2e_qg(self, context):
# Prepare the input text for end-to-end question generation
source_text = f"generate questions: {context}"
if self.model_type == "t5":
source_text = source_text + " </s>"
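# e.g. "generate questions: Paris is the capital of France. </s>" (illustrative, T5 format)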
@@ -271,6 +321,7 @@ def _tokenize(
add_special_tokens=True,
max_length=512
):
# Tokenize the inputs using the tokenizer, with options for padding and truncation
inputs = self.tokenizer.batch_encode_plus(
inputs,
max_length=max_length,
@@ -305,15 +356,17 @@ def _tokenize(
}
}

# Optional (if not specified it will be None, or "highlight" in the case of qg_format).

def pipeline(
task: str, # question-generation, multitask-qa-qg, e2e-qg
model: Optional = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
qg_format: Optional[str] = "highlight", # highlight or prepend
ans_model: Optional = None,
ans_tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
use_cuda: Optional[bool] = True, # use GPU?
**kwargs, # additional arguments
):
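# Illustrative usage (assumes the repository's default models can be downloaded):
#   nlp = pipeline("question-generation")
#   questions = nlp("42 is the answer to life, the universe and everything.")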
# Retrieve the task
if task not in SUPPORTED_TASKS:
@@ -327,14 +380,15 @@ def pipeline(
model = targeted_task["default"]["model"]

# Try to infer tokenizer from model or config name (if provided as str)
# to infer == to try to deduce, to guess
if tokenizer is None:
if isinstance(model, str): # is the object an instance of this class?
tokenizer = model
else:
# Impossible to guess what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provided a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
"Please provide a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)

# Instantiate tokenizer if needed
@@ -350,8 +404,8 @@ def pipeline(
model = AutoModelForSeq2SeqLM.from_pretrained(model)

if task == "question-generation":
if ans_model is None: # ans_model == the model used for answer extraction
# Load default answer extraction model
ans_model = targeted_task["default"]["ans_model"]
ans_tokenizer = AutoTokenizer.from_pretrained(ans_model)
ans_model = AutoModelForSeq2SeqLM.from_pretrained(ans_model)
@@ -361,10 +415,10 @@ def pipeline(
if isinstance(ans_model, str):
ans_tokenizer = ans_model
else:
# Impossible to guess what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provided a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
"Please provide a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)

# Instantiate tokenizer if needed