-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample_experiment.py
95 lines (76 loc) · 3.63 KB
/
sample_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from mlpipeline.entities import ExecutionModeKeys
from mlpipeline import (Versions,
MetricContainer,
iterator)
from mlpipeline.base import (ExperimentABC,
DataLoaderABC,
DataLoaderCallableWrapper)
class An_ML_Model:
    """Toy stand-in for a real ML model used by the sample experiment.

    "Training" just reports which hyperparameter value was in effect.
    """

    def __init__(self, hyperparameter="default value"):
        # Single tunable knob; a version spec may overwrite it after construction.
        self.hyperparameter = hyperparameter

    def train(self):
        # Pretend to train and echo the hyperparameter that was used.
        return "Trained using {}".format(self.hyperparameter)
class TestingDataLoader(DataLoaderABC):
    """Minimal fake dataloader for the sample experiment.

    Sample counts are fixed, and the input getters return zero-argument
    callables, mirroring the framework's input-fn contract.
    """

    def __init__(self):
        # self.log is provided by DataLoaderABC (defined outside this file).
        self.log("creating dataloader")

    def get_train_sample_count(self):
        # Fixed-size fake training set.
        return 1000

    def get_test_sample_count(self):
        # Fixed-size fake test set.
        return 1000

    def get_train_input(self, **kwargs):
        # Extra keyword arguments from the framework are accepted and ignored.
        return lambda: "got input form train input function"

    def get_test_input(self):
        return lambda: "got input form test input function"
class TestingExperiment(ExperimentABC):
    """Sample experiment demonstrating the mlpipeline hook API.

    Shows version specs, :class:`MetricContainer` usage and the
    train/evaluate loop hooks against the toy :class:`An_ML_Model`.

    Fix vs. the original: the no-op ``self.current_version =
    self.current_version`` in ``pre_execution_hook`` was removed (it reads
    and writes the same attribute; a no-op unless the base class turns
    ``current_version`` into a property with side effects — confirm).
    """

    def __init__(self, versions, **args):
        super().__init__(versions, **args)

    def setup_model(self):
        # Build the toy model and apply the current version's hyperparameter.
        self.model = An_ML_Model()
        self.model.hyperparameter = self.current_version["hyperparameter"]

    def pre_execution_hook(self, mode=ExecutionModeKeys.TEST):
        # Log run context before the train/eval loop starts.
        self.log("Pre execution")
        self.log("Version spec: {}".format(self.current_version))
        self.log(f"Experiment dir: {self.experiment_dir}")
        self.log(f"Dataloader: {self.dataloader}")

    def train_loop(self, input_fn):
        # Simple constructor form, shown for reference only: it is
        # immediately shadowed by the grouped form below and never used.
        metric_container = MetricContainer(metrics=['1', 'b', 'c'], track_average_epoch_count=5)
        # Grouped form: 'a','b','c' use the default averaging window (5
        # epochs); '2','d','e' carry their own window of 10 epochs.
        metric_container = MetricContainer(metrics=[{'metrics': ['a', 'b', 'c']},
                                                    {'metrics': ['2', 'd', 'e'],
                                                     'track_average_epoch_count': 10}],
                                           track_average_epoch_count=5)
        self.log("calling input fn")
        input_fn()
        for epoch in iterator(range(6)):
            for idx in iterator(range(6), 2):
                # Metrics are attribute-addressable on the container.
                metric_container.a.update(idx)
                metric_container.b.update(idx * 2)
                self.log("Epoch: {} step: {}".format(epoch, idx))
                self.log("a {}".format(metric_container.a.avg()))
                self.log("b {}".format(metric_container.b.avg()))
                if idx % 3 == 0:
                    # Periodically clear step-level accumulators mid-epoch.
                    metric_container.reset()
            # Per-epoch: log a subset of metrics, then close out the epoch.
            metric_container.log_metrics(['a', '2'])
            metric_container.reset_epoch()
        # Final summary across all metrics after training completes.
        metric_container.log_metrics()
        self.log("trained: {}".format(self.model.train()))
        self.copy_related_files("experiments/exports")

    def evaluate_loop(self, input_fn):
        self.log("calling input fn")
        input_fn()
        # Return fixed dummy metrics for the evaluation pass.
        metrics = MetricContainer(['a', 'b'])
        metrics.a.update(10, 1)
        metrics.b.update(2, 1)
        return metrics

    def export_model(self):
        # No real artifact to export in this sample.
        self.log("YAY! Exported!")
# --- Module-level wiring discovered by the mlpipeline runner ---
# Wrap the dataloader class so the framework can instantiate it on demand.
dl = DataLoaderCallableWrapper(TestingDataLoader)
# Base spec shared by every version: the dataloader wrapper, two positional
# values (their meaning is defined by Versions — confirm against mlpipeline
# docs) and a default learning rate.
v = Versions(dl, 1, 10, learning_rate=0.01)
# One version set via a plain keyword argument...
v.add_version("version1", hyperparameter="a hyperparameter")
# ...and three via the custom-parameters dict form.
# NOTE(review): 'custom_paramters' looks misspelled ('parameters') — confirm
# it matches mlpipeline's actual add_version signature before renaming.
v.add_version("version2", custom_paramters={"hyperparameter": None})
v.add_version("version3", custom_paramters={"hyperparameter": None})
v.add_version("version4", custom_paramters={"hyperparameter": None})
# Drop version3, then restrict to version1/version2 only.
v.filter_versions(blacklist_versions=["version3"])
v.filter_versions(whitelist_versions=["version1", "version2"])
# Added after the filters; whether earlier filters apply to later additions
# is framework-dependent — verify against mlpipeline's Versions docs.
v.add_version("version5", custom_paramters={"hyperparameter": None})
# Entry-point object the mlpipeline runner looks up in this module.
EXPERIMENT = TestingExperiment(versions=v)