-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_sentencepiece.py
49 lines (37 loc) · 1.76 KB
/
train_sentencepiece.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pathlib
import random
import sys
from argparse import ArgumentParser
import numpy as np
import sentencepiece
from simi.dataset import Data
from simi.segmentation import train_sentencepiece
from simi.utils import ensure_path
def _positive_int(text):
    """argparse ``type`` callback: parse *text* as an int and require it to be >= 1."""
    from argparse import ArgumentTypeError

    try:
        value = int(text)
    except ValueError:
        raise ArgumentTypeError(f'invalid int value: {text!r}') from None
    if value < 1:
        raise ArgumentTypeError(f'must be a positive integer, got {value}')
    return value


def parseArgs(argv=None):
    """Build the command-line parser for sentencepiece training and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets other code or tests drive the parser directly.

    Returns:
        argparse.Namespace with ``trainset`` and ``sentencepiece_prefix``
        (``pathlib.Path``), ``vocab_size``, ``seed`` and ``max_piece_length``
        (``int``) and ``overwrite`` (``bool``).

    A non-positive ``vocab_size`` is rejected here (argparse exits with
    status 2) instead of failing later, after the trainset has been loaded.
    """
    parser = ArgumentParser()
    parser.add_argument('trainset', type=pathlib.Path,
                        help='Path to the quantized trainset for sentencepiece learning.')
    parser.add_argument('sentencepiece_prefix', type=pathlib.Path,
                        help='Prefix for the sentencepiece model, last item in path should be the model\'s name')
    parser.add_argument('vocab_size', type=_positive_int,
                        help='Sentencepiece\'s vocabulary size.')
    parser.add_argument('--seed', type=int, default=290956,
                        help='Random seed')
    parser.add_argument('--overwrite', action='store_true',
                        help='Overwrite existing sentencepiece model')
    parser.add_argument('--max_piece_length', type=int, default=100,
                        help='Max length of sentence piece. Default=100')
    return parser.parse_args(argv)
def run(args):
    """Seed all RNGs, then train a sentencepiece model unless one already exists.

    Args:
        args: Namespace as returned by ``parseArgs``; reads ``seed``,
            ``sentencepiece_prefix``, ``overwrite``, ``trainset``,
            ``vocab_size`` and ``max_piece_length``.

    Returns ``None``. When ``<prefix>.model`` is already present and
    ``args.overwrite`` is false, it prints a notice and returns without
    loading data or training.
    """
    # Seed Python, NumPy and sentencepiece RNGs so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    sentencepiece.set_random_generator_seed(args.seed)

    # Append ".model" to the prefix instead of using with_suffix(): with_suffix
    # would replace a dot already in the model name (e.g. "sp_0.5k" ->
    # "sp_0.model") and check the wrong file. Assumes train_sentencepiece
    # writes "<prefix>.model", as SentencePiece's model_prefix does — confirm.
    prefix = pathlib.Path(args.sentencepiece_prefix)
    model_path = prefix.parent / f'{prefix.name}.model'
    # NOTE(review): ensure_path's result is used as "model already exists";
    # any other side effects it has are defined in simi.utils.
    if ensure_path(model_path) and not args.overwrite:
        print(f'Sentencepiece model found at {args.sentencepiece_prefix}, skipping training. If you want to overwrite, rerun with --overwrite.')
        # Return instead of sys.exit(0) so callers that import run() keep
        # control; when executed as a script the exit status is still 0.
        return

    print('Loading trainset...')
    trainset = Data(args.trainset)
    train_sentencepiece(trainset.data, args.sentencepiece_prefix, args.vocab_size, max_piece_length=args.max_piece_length)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to run().
    run(parseArgs())