QnA.py
import torch
from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast
class Chatbot:
    def __init__(self):
        # Load the tokenizer for the "jihae/kogpt2news" KoGPT2 checkpoint,
        # registering the special tokens the KoGPT2 vocabulary expects.
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "jihae/kogpt2news",
            bos_token='</s>', eos_token='</s>', unk_token='<unk>',
            pad_token='<pad>', mask_token='<mask>')
        # AutoModelWithLMHead is deprecated in recent transformers releases;
        # AutoModelForCausalLM is the current equivalent for this checkpoint.
        self.model = AutoModelWithLMHead.from_pretrained("jihae/kogpt2news")
    def get_answer(self, prompt):
        # Encode the prompt and generate a continuation with greedy decoding.
        input_ids = self.tokenizer.encode(prompt)
        gen_ids = self.model.generate(torch.tensor([input_ids]),
                                      max_length=128,
                                      repetition_penalty=4.0,
                                      pad_token_id=self.tokenizer.pad_token_id,
                                      eos_token_id=self.tokenizer.eos_token_id,
                                      bos_token_id=self.tokenizer.bos_token_id,
                                      use_cache=True)
        # Decode the generated ids (the output also contains the prompt itself)
        # and cut the text at the last full stop to avoid a dangling sentence.
        generated = self.tokenizer.decode(gen_ids[0, :].tolist())
        end = generated.rfind('.')
        answer = generated[:end + 1]
        return answer
# Example usage (the prompt asks, in Korean, whether early withdrawal is
# possible from a personal pension):
# ch = Chatbot()
# print(ch.get_answer('개인연금은 중도인출이 가능한가요?'))
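
# A minimal sketch of an alternative decoding strategy, not part of the original
# class: get_answer uses greedy decoding with a strong repetition penalty, which
# always returns the same answer for a given prompt, so one might sample instead.
# The helper name `get_sampled_answer` and the top_k/top_p values below are
# illustrative assumptions, not settings used by the original author.
def get_sampled_answer(bot, prompt, max_length=128):
    input_ids = torch.tensor([bot.tokenizer.encode(prompt)])
    gen_ids = bot.model.generate(input_ids,
                                 max_length=max_length,
                                 do_sample=True,  # sample instead of greedy decoding
                                 top_k=50,        # assumed value: keep the 50 most likely tokens
                                 top_p=0.95,      # assumed value: nucleus sampling threshold
                                 pad_token_id=bot.tokenizer.pad_token_id,
                                 eos_token_id=bot.tokenizer.eos_token_id,
                                 bos_token_id=bot.tokenizer.bos_token_id,
                                 use_cache=True)
    generated = bot.tokenizer.decode(gen_ids[0, :].tolist())
    # Same post-processing as get_answer: truncate at the last full stop.
    end = generated.rfind('.')
    return generated[:end + 1]

# Usage sketch:
# bot = Chatbot()
# print(get_sampled_answer(bot, '개인연금은 중도인출이 가능한가요?'))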