-
Notifications
You must be signed in to change notification settings - Fork 10
/
work_unit.py
executable file
·57 lines (39 loc) · 1.38 KB
/
work_unit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
A simple class to describe a train/validate dataset and a list of models to train on
that dataset.
"""
import copy
import os
import pickle
# Baseline featureset options. Every WorkUnit starts from a copy of this dict
# and layers its own ``featureset_params`` on top, so treat it as read-only.
DEFAULT_FEATURESET_PARAMS = {
    'interpolate': True,
    'max_interpolate': 2,
    'for_classification': False,
}


class WorkUnit(object):
    """A named featureset plus the model specs to train on it.

    Attributes:
        name: Identifier of this unit; also used as its directory name.
        featureset: The featureset value supplied by the caller (stored as-is).
        featureset_params: A fresh dict per instance: a copy of
            ``DEFAULT_FEATURESET_PARAMS`` updated with the caller's overrides.
        model_specs: Deep copies of the caller's model specs; each copy is
            given a ``work_unit`` attribute pointing back at this instance.
        work_spec: Parent spec, ``None`` until the owner assigns it
            (presumably the enclosing work spec -- confirm with callers).
    """

    def __init__(self, name, featureset, featureset_params=None, model_specs=None):
        """Create a work unit.

        Args:
            name: Name of the unit.
            featureset: Featureset this unit trains/validates on.
            featureset_params: Optional mapping of overrides applied on top of
                ``DEFAULT_FEATURESET_PARAMS``. The caller's mapping is only
                read, never modified.
            model_specs: Optional sequence of model specs. It is deep-copied,
                so the caller's spec objects are never re-parented.
        """
        self.name = name
        self.featureset = featureset
        # ``None`` sentinels instead of ``{}``/``[]`` literals: default
        # literals are evaluated once and shared by every call.
        if model_specs is None:
            model_specs = []
        self.model_specs = copy.deepcopy(model_specs)
        self.work_spec = None
        self.featureset_params = DEFAULT_FEATURESET_PARAMS.copy()
        if featureset_params is not None:
            self.featureset_params.update(featureset_params)
        # Give every copied spec a back-reference to its owning unit.
        for model_spec in self.model_specs:
            model_spec.work_unit = self

    def get_directory(self):
        """Return ``work_spec.get_directory() + name + '/'``.

        Raises:
            AttributeError: if ``work_spec`` has not been assigned yet.
        """
        if self.work_spec is None:
            # Same exception type the bare attribute access used to raise,
            # but the message now says what is actually missing.
            raise AttributeError(
                "WorkUnit %r has no work_spec; assign one before calling "
                "get_directory()" % (self.name,))
        # NOTE(review): plain concatenation assumes the parent directory
        # already ends with a separator -- confirm against work_spec.
        return self.work_spec.get_directory() + self.name + '/'

    def get_all_model_dirs(self):
        """Return ``get_directory()`` of every model spec, in order."""
        return [model_spec.get_directory() for model_spec in self.model_specs]

    def get_all_results_obj_paths(self):
        """Return ``get_results_obj_path()`` of every model spec, in order."""
        return [model_spec.get_results_obj_path()
                for model_spec in self.model_specs]

    def get_all_results_objs(self):
        """Return ``get_results_obj()`` of every model spec, in order."""
        return [model_spec.get_results_obj() for model_spec in self.model_specs]