-
Notifications
You must be signed in to change notification settings - Fork 35
/
Figure4_ParameterRecovery.m
94 lines (73 loc) · 2.29 KB
/
Figure4_ParameterRecovery.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
% Parameter-recovery figure for Model 3 (Rescorla-Wagner).
% Preamble: reset the workspace and put the project's function folders
% (paths relative to the current folder) on the MATLAB path.
clear
addpath('./SimulationFunctions')
addpath('./AnalysisFunctions')
addpath('./HelperFunctions')
addpath('./FittingFunctions')
addpath('./LikelihoodFunctions')

% Plot colours, declared global -- presumably so plotting helpers elsewhere
% in the project can read them (NOTE(review): confirm against callers).
% Each colour is an 8-bit RGB triplet normalised to MATLAB's [0,1] range by
% dividing by 255, the maximum 8-bit value. (Dividing by 256 darkens every
% colour slightly and can never reach 1.0.) In hex, e.g.,
% AZred = #AB0520 and AZblue = #0C234B.
global AZred AZblue AZcactus AZsky AZriver AZsand AZmesa AZbrick
AZred    = [171,   5,  32]/255;
AZblue   = [ 12,  35,  75]/255;
AZcactus = [ 92, 135,  39]/255;
AZsky    = [132, 210, 226]/255;
AZriver  = [  7, 104, 115]/255;
AZsand   = [241, 158,  31]/255;
AZmesa   = [183,  85,  39]/255;
AZbrick  = [ 74,  48,  39]/255;
% experiment parameters
T    = 1000;        % number of trials per simulated session
mu   = [0.2 0.8];   % mean reward of bandits
nSim = 1000;        % number of simulate-and-fit repetitions

% fixed seed so the recovery results are reproducible
rng(2);

% Model 3: Rescorla Wagner
% For each repetition: draw "true" parameters, simulate a session, refit the
% model, and store simulated vs. fitted parameters column-wise
% (row 1 = alpha, row 2 = beta).
Xsim = nan(2, nSim);   % preallocated: simulated (ground-truth) parameters
Xfit = nan(2, nSim);   % preallocated: fitted parameters
for count = 1:nSim
    alpha = rand;          % alpha ~ Uniform(0,1)
    beta  = exprnd(10);    % beta ~ Exponential with mean 10
    [a, r] = simulate_M3RescorlaWagner_v1(T, mu, alpha, beta);
    % log-likelihood and BIC outputs are still requested (same nargout as
    % before) but are not needed for parameter recovery
    [Xf, ~, ~] = fit_M3RescorlaWagner_v1(a, r);
    Xsim(1, count) = alpha;
    Xsim(2, count) = beta;
    Xfit(1, count) = Xf(1);
    Xfit(2, count) = Xf(2);
end
%% basic parameter recovery plots
% One panel per parameter, showing fitted vs. simulated values with an
% identity line for reference. Simulations whose learning rate was
% recovered poorly (|alpha_sim - alpha_fit| > thresh) are re-drawn with a
% gray fill in every panel, so the same simulations can be traced across
% parameters.
names   = {'learning rate' 'softmax temperature'};
symbols = {'\alpha' '\beta'};
nPar = size(Xsim, 1);

figure(1); clf;
set(gcf, 'Position', [811 613 600 300])
[~,~,~,ax] = easy_gridOfEqualFigures([0.2 0.1], [0.1 0.18 0.04]);

% all simulations
for i = 1:nPar
    axes(ax(i)); hold on;
    plot(Xsim(i,:), Xfit(i,:), 'o', 'color', AZred, 'markersize', 8, 'linewidth', 1)
end

% find 'bad' alpha values and highlight those simulations in every panel
thresh = 0.25;
ind = abs(Xsim(1,:) - Xfit(1,:)) > thresh;
for i = 1:nPar
    axes(ax(i));
    plot(Xsim(i,ind), Xfit(i,ind), 'o', 'color', AZblue, 'markersize', 8, 'linewidth', 1, ...
        'markerfacecolor', [1 1 1]*0.5)
end

% second panel (beta) on log-log axes
set(ax(2), 'xscale', 'log', 'yscale', 'log')

% titles and axis labels, built from names/symbols
t = gobjects(1, nPar);
for i = 1:nPar
    axes(ax(i));
    t(i) = title(names{i});
    xlabel(['simulated ' symbols{i}]);
    ylabel(['fit ' symbols{i}]);
end
set(ax, 'tickdir', 'out', 'fontsize', 18)
set(t, 'fontweight', 'normal')
addABCs(ax(1), [-0.07 0.08], 32)
addABCs(ax(2), [-0.1 0.08], 32, 'B')

% Identity line, drawn once and only after the log scaling, so it spans
% each panel's final x-limits. Drawing it before the scale change would
% use linear limits, which can start at 0 and fail to render on a log axis.
for i = 1:nPar
    axes(ax(i));
    xl = get(gca, 'xlim');
    plot(xl, xl, 'k--')
end

% output file stem, shared by all three export helpers
figFile = '~/Desktop/Figure4';
saveFigurePdf(gcf, figFile)
saveFigureEps(gcf, figFile)
saveFigurePng(gcf, figFile)
%% fitted alpha vs. fitted beta (diagnostic)
% Scatter of the two fitted parameters against each other, presumably to
% look for trade-offs between them. This plot uses its own figure window so
% the parameter-recovery figure above is not wiped when the whole script is
% run.
figure(2); clf; hold on;
% plot(Xsim(1,:), Xsim(2,:),'.')
plot(Xfit(2,:), Xfit(1,:), '.')
set(gca, 'xscale', 'log')
xlabel('fit \beta'); ylabel('fit \alpha');