-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path01_introduction.py
62 lines (47 loc) · 2.02 KB
/
01_introduction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import mxnet as mx
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)

# Synthetic dataset sizes: the first train_count samples are used for
# training, the remaining valid_count samples for validation.
sample_count = 1000
train_count = 800
valid_count = sample_count - train_count
feature_count = 100
category_count = 10
batch = 10

# Features: uniform random values, shape (sample_count, feature_count).
X = mx.nd.uniform(low=0, high=1, shape=(sample_count, feature_count))
# Labels: one random class index in [0, category_count) for EVERY sample,
# drawn in a single vectorized call and stored as float32. The labels are
# independent of X, so validation accuracy is expected to stay near
# 1 / category_count -- this script exercises the API, not learning.
Y = mx.nd.array(
    np.random.randint(0, category_count, size=sample_count), dtype='float32'
)

# Split along the first axis: rows [0, train_count) for training,
# rows [train_count, sample_count) for validation.
X_train = X[0:train_count]
Y_train = Y[0:train_count]
X_valid = X[train_count:sample_count]
Y_valid = Y[train_count:sample_count]
print(X.shape, Y.shape, X_train.shape, Y_train.shape, X_valid.shape, Y_valid.shape)
# Symbolic graph: input -> FC(64) -> ReLU -> FC(category_count) -> softmax output,
# wrapped in a Module for training and prediction.
data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data=data, num_hidden=64, name='fc1')
relu1 = mx.sym.Activation(data=fc1, act_type="relu", name='relu1')
fc2 = mx.sym.FullyConnected(data=relu1, num_hidden=category_count, name='fc2')
out = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
mod = mx.mod.Module(symbol=out)
# Build iterators over the in-memory train/validation splits.
train_iter = mx.io.NDArrayIter(data=X_train, label=Y_train, batch_size=batch)
val_iter = mx.io.NDArrayIter(data=X_valid, label=Y_valid, batch_size=batch)

# Train model. fit() binds the module, initializes the parameters and creates
# the optimizer itself, so the settings are handed straight to it. Calling
# bind/init_params/init_optimizer beforehand would make fit() skip those
# steps, log "already initialized" warnings and ignore its own arguments.
mod.fit(
    train_iter,
    initializer=mx.init.Xavier(magnitude=2.),
    optimizer='sgd',
    optimizer_params=(('learning_rate', 0.1),),
    num_epoch=50,
)
# Evaluate on the validation split. iter_predict() yields
# (outputs, batch_index, data_batch) for each batch.
pred_count = 0
total_correct_preds = 0
# The loop variable is named val_batch so it does not overwrite the
# module-level batch size `batch`.
for preds, _, val_batch in mod.iter_predict(val_iter):
    pred_label = preds[0].asnumpy().argmax(axis=1)
    # If the final batch is padded up to the batch size, the predictions can
    # be shorter than the batch's labels; compare only the matching prefix.
    label = val_batch.label[0].asnumpy().astype(int)[:len(pred_label)]
    total_correct_preds += int(np.sum(pred_label == label))
    pred_count += len(pred_label)
# Divide by the number of predictions actually made rather than a
# hard-coded sample count.
print('Validation accuracy: %2.2f' % (1.0 * total_correct_preds / pred_count))