# coding: UTF-8
"""
file name: evaluate.py
create time: Monday, June 11, 2018, 13:53:42
author: Jipeng Huang
e-mail: [email protected]
github: https://github.com/hjptriplebee
"""
# evaluate the model; intended for testing only
import os

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.Session, tf.contrib)

import data
from model import *  # MODEL class plus config constants (batchSize, epochNum, saveStep, ...)

class EVALUATE_MODEL(MODEL):
    """evaluate model class"""
    def evaluate(self, reload=True):
        """train the model, then measure accuracy on the test batches"""
        print("training...")
        gtX = tf.placeholder(tf.int32, shape=[batchSize, None])  # input
        gtY = tf.placeholder(tf.int32, shape=[batchSize, None])  # output
        logits, probs, _, _, _ = self.buildModel(self.trainData.wordNum, gtX)
        targets = tf.reshape(gtY, [-1])
        # loss: per-token cross-entropy between logits and targets
        loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
            [logits], [targets], [tf.ones_like(targets, dtype=tf.float32)])
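        # Note: sequence_loss_by_example returns one cross-entropy value per
        # target token; the all-ones weight vector gives every position equal
        # importance in the averaged cost below.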
        globalStep = tf.Variable(0, trainable=False)
        addGlobalStep = globalStep.assign_add(1)
        cost = tf.reduce_mean(loss)
        trainableVariables = tf.trainable_variables()
        # clip gradients to prevent loss divergence caused by gradient explosion
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, trainableVariables), 5)
        learningRate = tf.train.exponential_decay(learningRateBase, global_step=globalStep,
                                                  decay_steps=learningRateDecayStep,
                                                  decay_rate=learningRateDecayRate)
        optimizer = tf.train.AdamOptimizer(learningRate)
        trainOP = optimizer.apply_gradients(zip(grads, trainableVariables))
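        # The graph is now complete: gradients are clipped to a global norm of 5
        # before Adam applies them, and the learning rate decays exponentially
        # as globalStep grows.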
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            if not os.path.exists(evaluateCheckpointsPath):
                os.mkdir(evaluateCheckpointsPath)
            if reload:
                checkPoint = tf.train.get_checkpoint_state(evaluateCheckpointsPath)
                # if a checkpoint exists, restore it
                if checkPoint and checkPoint.model_checkpoint_path:
                    saver.restore(sess, checkPoint.model_checkpoint_path)
                    print("restored %s" % checkPoint.model_checkpoint_path)
                else:
                    print("no checkpoint found!")
            for epoch in range(epochNum):
                X, Y = self.trainData.generateBatch()
                epochSteps = len(X)  # number of batches per epoch
                for step, (x, y) in enumerate(zip(X, Y)):
                    _, lossValue, gStep = sess.run([trainOP, cost, addGlobalStep],
                                                   feed_dict={gtX: x, gtY: y})
                    print("epoch: %d, steps: %d/%d, loss: %.3f" % (epoch + 1, step + 1, epochSteps, lossValue))
                    if gStep % saveStep == saveStep - 1:  # avoid saving right at the beginning
                        print("save model")
                        # "type" is a config string re-exported by model (it shadows the builtin)
                        saver.save(sess, os.path.join(evaluateCheckpointsPath, type), global_step=gStep)
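            # Evaluation: run the held-out batches once and count token-level
            # next-word prediction errors.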
            X, Y = self.trainData.generateBatch(isTrain=False)
            print("evaluating testing error...")
            wrongNum = 0
            totalNum = 0
            testBatchNum = len(X)
            for step, (x, y) in enumerate(zip(X, Y)):
                print("test batch %d/%d" % (step + 1, testBatchNum))
                testProbs, testTargets = sess.run([probs, targets], feed_dict={gtX: x, gtY: y})
                # count positions where the argmax prediction differs from the target
                wrongNum += len(np.nonzero(np.argmax(testProbs, axis=1) - testTargets)[0])
                totalNum += len(testTargets)
            print("accuracy: %.2f" % (float(totalNum - wrongNum) / totalNum))

if __name__ == "__main__":
    trainData = data.POEMS(trainPoems, isEvaluate=True)
    MCPangHu = EVALUATE_MODEL(trainData)
    MCPangHu.evaluate()
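
# Usage sketch (assumes a TF 1.x environment and the config values re-exported
# through model.py, e.g. trainPoems pointing at the poem corpus):
#
#     python evaluate.py
#
# The script trains for epochNum epochs, checkpoints to evaluateCheckpointsPath,
# and finally reports token-level accuracy on the held-out batches.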