-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli.py
73 lines (61 loc) · 2.13 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import click, json
from tools.wrappers import create_and_train_text_model, list_saved_models,\
predict_with_model
# Root command group of the CLI; the commands defined further down in this
# file register themselves on it through ``@my_main.command()``.
@click.group()
#@click.version_option(version='1.0.0')
def my_main():
    """
    CLI interface for training, saving and using models
    """
    # No body needed: the group only dispatches to its sub-commands, and the
    # docstring above is what click shows as the top-level --help text.
@my_main.command()
def list_models():
    """
    Shows a list of available (saved) models
    """
    # Fetch first so that nothing is printed if the lookup itself fails.
    saved_info = list_saved_models()
    print('Saved Models Info:')
    # Dump whatever the wrapper returned as indented JSON.
    present(saved_info)
@my_main.command()
@click.argument('dataset_tsv')
@click.argument('model_name')
@click.option('--epochs', default=1, help='how many epochs to train for')
@click.option('--limit', default=0, help='limit the amount of training data')
@click.option('--delimiter', default="\t",
              help='the character to use as the delim in the training file')
@click.option('--validation-size', default=10_000,
              help='the size of the validation (when training)')
def train(**kwargs):
    """
    Trains a model with a dataset and saves the model file(s)
    """
    # NOTE: the docstring above is the text click prints for `train --help`,
    # so developer-facing notes live in comments instead.
    #
    # click hands over every declared argument/option by keyword, using the
    # declared name with dashes turned into underscores
    # (``--validation-size`` -> ``validation_size``).
    train_fname = kwargs['dataset_tsv']
    model_name = kwargs['model_name']
    epochs = kwargs['epochs']
    # NOTE(review): a limit of 0 (the default) presumably means "no limit";
    # confirm against tools.wrappers.create_and_train_text_model.
    limit = kwargs['limit']
    val_size = kwargs['validation_size']
    delim = kwargs['delimiter']

    # Echo the effective settings before the (potentially long) training run.
    print('training on:', train_fname, 'saved as', model_name, 'with', epochs,
          'epochs', 'limit', limit)

    description = create_and_train_text_model(
        train_fname, model_name, epochs, val_size=val_size,
        limit=limit, delim=delim)

    # Show whatever description the wrapper returned as indented JSON.
    print('Trained Model Info:')
    present(description)
# Arguments are declared together, followed by the option, mirroring `train`.
# Positional order on the command line is unchanged:
# INFERENCE_SET_FILE first, then MODEL_NAME.
@my_main.command()
@click.argument('inference_set_file')
@click.argument('model_name')
@click.option('--limit', default=0, help='limit the amount of inference data')
def predict(**kwargs):
    """
    Predicts the labels for each line of given file based on the given model
    """
    # NOTE: the docstring above is the text click prints for
    # `predict --help`, so developer-facing notes live in comments instead.
    model_name = kwargs['model_name']
    inf_fname = kwargs['inference_set_file']
    # NOTE(review): a limit of 0 (the default) presumably means "no limit";
    # confirm against tools.wrappers.predict_with_model.
    limit = kwargs['limit']

    predictions = predict_with_model(model_name, inf_fname, limit=limit)
    # Print the wrapper's result as indented JSON.
    present(predictions)
def present(h):
    """Print ``h`` to stdout as JSON.

    The value is serialized with a 2-space indent and non-ASCII characters
    are written as-is rather than as ``\\uXXXX`` escapes.
    """
    rendered = json.dumps(h, ensure_ascii=False, indent=2)
    print(rendered)
# Start the CLI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    my_main()