diff --git a/modules/tleco/tleco_dataset.py b/modules/tleco/tleco_dataset.py index 60401ef..0e275e0 100644 --- a/modules/tleco/tleco_dataset.py +++ b/modules/tleco/tleco_dataset.py @@ -7,14 +7,14 @@ def __init__( text_model, path, # 置き換え対象のテキストが書かれたファイルのパス batch_size=1, - pad_tokens=None, # パディングトークンの文字列 + pad_tokens=None, # パディングトークンの文字列リスト selected_tags=None, # 追加するタグのリストが書かれたファイルのパス drop_rate=0.0, # 置き換え対象をドロップする確率(正則化) shuffle=True, - prefix="", # 先頭文字列 + prefix=None, # 先頭文字列 repeat=1, # データセットの反復回数 max_add_length=16, # 追加するタグの最大数 - position_rate=0.5, # タグの位置 (先頭 or 末尾) + position_rate=0.5, # タグの位置を先頭にする確率(1-rで末尾) ): self.batch_size = batch_size self.shuffle = shuffle @@ -25,7 +25,7 @@ def __init__( with open(path, 'r') as f: self.texts = f.read().splitlines() - self.text_pairs = [[tag.strip() for tag in text.split(",")] for text in self.texts] + self.text_pairs = [[tag.strip() for tag in text.split("|")] for text in self.texts] if pad_tokens is not None: for pad_token in pad_tokens: @@ -84,18 +84,18 @@ def __getitem__(self, idx): elif source_length < target_length: source += " " + " ".join(padding) - last_source = self.prefix - last_target = self.prefix + last_source = self.prefix + ", " if self.prefix else "" + last_target = self.prefix + ", " if self.prefix else "" if random.random() > self.drop_rate: if random.random() > self.position_rate: - last_source += ", " + source + ", " + add_tags - last_target += ", " + target + ", " + add_tags + last_source += source + ", " + add_tags + last_target += target + ", " + add_tags else: - last_source += ", " + add_tags + ", " + source - last_target += ", " + add_tags + ", " + target + last_source += add_tags + ", " + source + last_target += add_tags + ", " + target else: - last_source += ", " + add_tags - last_target += ", " + add_tags + last_source += add_tags + last_target += add_tags sources.append(last_source) targets.append(last_target) @@ -105,4 +105,4 @@ def __getitem__(self, idx): "target": targets } - \ No newline at end of file +