Skip to content

Commit

Permalink
分割記号の変更とか
Browse files Browse the repository at this point in the history
  • Loading branch information
laksjdjf authored Mar 21, 2024
1 parent a7e764d commit 65c46f4
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions modules/tleco/tleco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ def __init__(
text_model,
path, # 置き換え対象のテキストが書かれたファイルのパス
batch_size=1,
- pad_tokens=None, # パディングトークンの文字列
+ pad_tokens=None, # パディングトークンの文字列リスト
selected_tags=None, # 追加するタグのリストが書かれたファイルのパス
drop_rate=0.0, # 置き換え対象をドロップする確率(正則化)
shuffle=True,
- prefix="", # 先頭文字列
+ prefix=None, # 先頭文字列
repeat=1, # データセットの反復回数
max_add_length=16, # 追加するタグの最大数
- position_rate=0.5, # タグの位置 (先頭 or 末尾)
+ position_rate=0.5, # タグの位置を先頭にする確率(1-rで末尾)
):
self.batch_size = batch_size
self.shuffle = shuffle
Expand All @@ -25,7 +25,7 @@ def __init__(

with open(path, 'r') as f:
self.texts = f.read().splitlines()
- self.text_pairs = [[tag.strip() for tag in text.split(",")] for text in self.texts]
+ self.text_pairs = [[tag.strip() for tag in text.split("|")] for text in self.texts]

if pad_tokens is not None:
for pad_token in pad_tokens:
Expand Down Expand Up @@ -84,18 +84,18 @@ def __getitem__(self, idx):
elif source_length < target_length:
source += " " + " ".join(padding)

- last_source = self.prefix
- last_target = self.prefix
+ last_source = self.prefix + ", " if self.prefix else ""
+ last_target = self.prefix + ", " if self.prefix else ""
if random.random() > self.drop_rate:
if random.random() > self.position_rate:
- last_source += ", " + source + ", " + add_tags
- last_target += ", " + target + ", " + add_tags
+ last_source += source + ", " + add_tags
+ last_target += target + ", " + add_tags
else:
- last_source += ", " + add_tags + ", " + source
- last_target += ", " + add_tags + ", " + target
+ last_source += add_tags + ", " + source
+ last_target += add_tags + ", " + target
else:
- last_source += ", " + add_tags
- last_target += ", " + add_tags
+ last_source += add_tags
+ last_target += add_tags

sources.append(last_source)
targets.append(last_target)
Expand All @@ -105,4 +105,4 @@ def __getitem__(self, idx):
"target": targets
}



0 comments on commit 65c46f4

Please sign in to comment.