-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrbmrun.c
111 lines (93 loc) · 2.39 KB
/
rbmrun.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// vim: tabstop=4:softtabstop=4:shiftwidth=4:expandtab:smarttab
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <ncurses.h>
#include "include/rbm.h"
#include "config.h"
/**
 * Draw the current state of every neuron layer of the DBN with ncurses.
 *
 * Screen row i (for i = 0 .. nlayer-2) shows the visible units of RBM
 * layer i; the final row (nlayer-1) shows the hidden units of the topmost
 * RBM (layer nlayer-2).  Each unit occupies one screen column and is
 * printed with "%d", so the layout assumes single-digit unit values
 * (main() seeds layer 0 with 0/1).  The screen is refreshed afterwards.
 *
 * Nothing is drawn when obj_in is NULL or describes fewer than two neuron
 * layers: there is then no RBM, and layers[nlayer-2] would be out of bounds.
 *
 * @param obj_in  DBN to display; assumed to be already assembled.
 */
void dbn_print_curses( dbn_t *obj_in )
{
    int i, j;
    int top;

    if( obj_in == NULL || obj_in->nlayer < 2 )
        return;

    /* Index of the topmost RBM; its hidden units form the last row. */
    top = obj_in->nlayer - 2;

    for( i = 0; i < obj_in->nlayer - 1; i++ )
    {
        for( j = 0; j < obj_in->layers[i].nv; j++ )
            mvprintw( i, j, "%d", (int) obj_in->layers[i].vis[j] );
    }
    for( j = 0; j < obj_in->layers[top].nh; j++ )
        mvprintw( obj_in->nlayer - 1, j, "%d", (int) obj_in->layers[top].hid[j] );
    refresh();
}
/**
 * This function makes use of the ncurses library to display the
 * neural network in real time as it is updated. It also allows
 * manual control of time stepping.
 *
 * Each iteration draws the network and then advances it by one step:
 * GIBBS_STEPS_PER_FRAME rounds of alternating hidden/visible updates on
 * the topmost RBM, followed by a top-down pass that updates the visible
 * units of every lower RBM.  The loop then blocks on a key press: 'q'
 * quits, and any other key advances another step.  Because raw() is
 * enabled, Ctrl-C arrives as an ordinary key instead of a signal, so 'q'
 * is the normal exit.  The loop also exits if getch() returns ERR;
 * otherwise a persistent input error would spin forever.
 *
 * @param obj_in  DBN to run; the caller must have set it up already.
 *                Returns immediately, without touching the terminal, if
 *                obj_in is NULL or has fewer than two neuron layers.
 */
void dbn_run( dbn_t *obj_in )
{
    /* Alternating hid/vis updates of the top RBM per displayed step. */
    enum { GIBBS_STEPS_PER_FRAME = 8 };
    long i;
    long top;
    /* int, not char: with keypad() enabled getch() returns KEY_* codes
     * above 255 (which a char would truncate, possibly to 'q') and ERR. */
    int ch;

    if( obj_in == NULL || obj_in->nlayer < 2 )
        return;
    top = obj_in->nlayer - 2;

    /* Initialize ncurses */
    initscr();
    raw();
    keypad( stdscr, TRUE );
    noecho();

    for(;;)
    {
        /* Print the output first, then update */
        dbn_print_curses( obj_in );

        for( i = 0; i < GIBBS_STEPS_PER_FRAME; i++ )
        {
            rbm_update_hid( &(obj_in->layers[top]) );
            rbm_update_vis( &(obj_in->layers[top]) );
        }
        /* Top-down pass through the lower RBMs. */
        for( i = top - 1; i >= 0; i-- )
            rbm_update_vis( &(obj_in->layers[i]) );

        ch = getch();
        if( ch == 'q' || ch == ERR )
            break;
    }
    endwin();
}
/**
 * Print the package banner and the command-line synopsis to stdout.
 * Both arguments are required; main() prints this message when fewer
 * than two are supplied.
 */
void usage( void )
{
    printf( "%s\n", PACKAGE_STRING );
    printf( "Usage: rbmrun <params> <temp>\n" );
}
/**
 * Entry point: load a DBN, seed its bottom layer with a random visible
 * configuration, propagate that upward, burn in the top RBM's alternating
 * hidden/visible updates, then hand control to the interactive ncurses
 * viewer (dbn_run).
 *
 * Usage: rbmrun <params> <temp>
 *   params  path handed to dbn_load_init()
 *   temp    temperature applied to every RBM via rbm_set_temp(); must be
 *           a finite number
 *
 * @return 0 on success; EXIT_FAILURE on missing/invalid arguments or a
 *         network with fewer than two neuron layers.
 */
int main( int argc, char *argv[] )
{
    /* Top-RBM hid/vis update rounds before the first frame is shown. */
    enum { BURN_IN_STEPS = 512 };
    long i;
    double temp;
    char *end;
    dbn_t dbn;

    if( argc < 3 )
    {
        usage();
        return EXIT_FAILURE;
    }

    /* Validate the temperature before doing any work: atof() would
     * silently turn a malformed argument into 0.0. */
    temp = strtod( argv[2], &end );
    if( end == argv[2] || *end != '\0' || !isfinite( temp ) )
    {
        fprintf( stderr, "rbmrun: invalid temperature '%s'\n", argv[2] );
        return EXIT_FAILURE;
    }

    /* Load the network parameters and build the layer structure.
     * NOTE(review): the return values of dbn_load_init()/dbn_assemble()
     * are ignored; check their declarations for an error convention and
     * handle failures here. */
    dbn_load_init( &dbn, argv[1] );
    dbn_assemble( &dbn );

    /* At least one RBM (two neuron layers) is required: everything
     * below indexes layers[nlayer-2]. */
    if( dbn.nlayer < 2 )
    {
        fprintf( stderr, "rbmrun: '%s' describes fewer than two layers\n", argv[1] );
        return EXIT_FAILURE;
    }

    for( i = 0; i < dbn.nlayer - 1; i++ )
        rbm_set_temp( &dbn.layers[i], temp );

    /* Seed from the clock so each run starts from a different random
     * visible configuration. */
    srand( (unsigned) time( NULL ) );
    for( i = 0; i < dbn.layers[0].nv; i++ )
        dbn.layers[0].vis[i] = (char) ( ( (double) rand() / (double) RAND_MAX < 0.5 ) ? 1 : 0 );

    /* Propagate upward through every RBM, then sample the top RBM's
     * visible units.  Presumably dbn_assemble() links each layer's hidden
     * units to the next layer's visible units - confirm in rbm.h. */
    for( i = 0; i < dbn.nlayer - 1; i++ )
        rbm_update_hid( &dbn.layers[i] );
    rbm_update_vis( &dbn.layers[dbn.nlayer-2] );

    for( i = 0; i < BURN_IN_STEPS; i++ )
    {
        rbm_update_hid( &dbn.layers[dbn.nlayer-2] );
        rbm_update_vis( &dbn.layers[dbn.nlayer-2] );
    }

    /* Run free via ncurses */
    dbn_run( &dbn );
    return 0;
}