-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo.m
54 lines (39 loc) · 1.27 KB
/
demo.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
% DEMO  K-fold cross-validation of the CTBN multi-label classifier
% [Batal, Hong, Hauskrecht 2013] on a single dataset.
%
% Requires data/<dataset_name>.mat containing variables X and Y whose rows
% are aligned (row i of X goes with row i of Y; both are indexed by the same
% cvpartition masks below). Presumably X is instances-by-features and Y is
% instances-by-labels -- TODO confirm against the data files.
%
% After the run, CTBN{r} holds the evaluation measures for fold r (from
% getMeasuresMLC); the trained per-fold models are not kept.
%
% NOTE(review): cvpartition draws a random split, so results vary between
% runs. Call rng(<seed>) before cvpartition if reproducibility is needed.
clear;clc;
% path -- resolve helper folders relative to this script, not the current
% working directory. When there is no script file (e.g. the code is pasted
% into the Command Window) mfilename returns '', so root_dir is '' and the
% paths fall back to being relative to the cwd, as before.
root_dir = fileparts(mfilename('fullpath'));
addpath(fullfile(root_dir, 'auxiliary'));
addpath(fullfile(root_dir, 'CTBN'));
addpath(fullfile(root_dir, 'liblinear', 'matlab')); %% !! Liblinear 1.92 or above is required !!
% init -- settings shared with the CTBN code through globals
global LR_implementation;
global learn_LR_cost;
LR_implementation = 'liblinear'; % designate the library to use (for logistic regression)
learn_LR_cost = false; % if true, it learns the regularization coefficient on the fly
% sample run
dataset_name = 'emotions';
% Load only X and Y into a struct so that other variables stored in the
% .mat file cannot overwrite script variables or the globals set above.
data = load(fullfile(root_dir, 'data', [dataset_name '.mat']), 'X', 'Y');
X = data.X;
Y = data.Y;
clear data;
assert(size(X, 1) == size(Y, 1), ...
    'demo:sizeMismatch', 'X has %d rows but Y has %d rows.', size(X, 1), size(Y, 1));
fprintf('[Training & testing CTBN on ''%s'']\n', dataset_name);
% 10-fold cv, stratified on the first column of Y only
K = 10;
CVO = cvpartition(Y(:,1), 'kfold', K);
CTBN = cell(1, K);            % per-fold evaluation measures
Y_pred_CTBN = cell(1, K);     % per-fold predictions on the test rows
Y_log_prob_CTBN = cell(1, K); % per-fold second output of MAP_prediction_sw
for r = 1:CVO.NumTestSets
    fprintf('msg: round %d/%d... ', r, CVO.NumTestSets);
    tic;
    X_tr = X(CVO.training(r), :);
    Y_tr = Y(CVO.training(r), :);
    X_ts = X(CVO.test(r), :);
    Y_ts = Y(CVO.test(r), :);
    % CTBN [Batal, Hong, Hauskrecht 2013]
    % train
    CTBN_model = learn_output_tree_sw(X_tr, Y_tr);
    % test -- the true test labels are passed in as well, presumably so the
    % second output can score them (TODO confirm in MAP_prediction_sw)
    [Y_pred_CTBN{r}, Y_log_prob_CTBN{r}] = MAP_prediction_sw(CTBN_model, X_ts, Y_ts);
    toc; % prints this fold's elapsed time
    % bookkeeping
    CTBN{r} = getMeasuresMLC(Y_ts, Y_pred_CTBN{r}, Y_log_prob_CTBN{r});
end
% report results
fprintf('\n[Test results on ''%s'']\n', dataset_name);
process_results(CTBN);