-
Notifications
You must be signed in to change notification settings - Fork 0
/
text_generation.py
36 lines (29 loc) · 1.37 KB
/
text_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
def generate_text(prompt, model_name='gpt2', max_length=200, num_return_sequences=1):
# トークナイザーとモデルのロード
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# モデルを評価モードに設定
model.eval()
# プロンプトをトークン化
inputs = tokenizer.encode(prompt, return_tensors='pt')
# テキストの生成
with torch.no_grad():
outputs = model.generate(
inputs,
max_length=max_length,
num_return_sequences=num_return_sequences,
pad_token_id=tokenizer.eos_token_id,
do_sample=True, # サンプリングを有効にすることで多様な生成が可能
top_k=50, # トップKサンプリングを有効に
top_p=0.95 # トップPサンプリングを有効に
)
# 生成されたテキストのデコード
generated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
return generated_texts
if __name__ == "__main__":
prompt = "Once upon a time"
generated_texts = generate_text(prompt, max_length=100, num_return_sequences=1)
for i, text in enumerate(generated_texts):
print(f"Generated Text {i + 1}:\n{text}\n")