-
Notifications
You must be signed in to change notification settings - Fork 13
/
tutorial_linearModel.m
173 lines (145 loc) · 9.12 KB
/
tutorial_linearModel.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
% Demo code for how to use the linear encoding model used in the study
% ‘Single-trial neural dynamics are dominated by richly varied movements’
% by Musall, Kaufman et al., 2019.
%
% This code shows how to build a design matrix based on task and movement
% events, run the linear model, analyze the fitted beta weights,
% and quantify cross-validated explained variance.
%
% To run, add the repo to your matlab path and download our widefield dataset
% from the CSHL repository, under http://repository.cshl.edu/38599/.
% Download the data folder 'Widefield' and navigate to its location in Matlab.
%% get some data from example recording
% Loads options, atlas information and the imaging data of one demo session.
% All session files are read from <current folder>/Widefield/<animal>/<rec>/,
% so Matlab has to be pointed at the folder that contains 'Widefield'.
animal = 'mSM43'; %example animal
rec = '23-Nov-2017'; %example recording
fPath = [pwd filesep 'Widefield' filesep animal filesep rec filesep]; %path to demo recording
load([fPath 'opts2.mat'], 'opts'); % load some options
load('allenDorsalMapSM.mat', 'dorsalMaps'); % load allen atlas info (found via the Matlab path, not fPath)
load([fPath 'Vc.mat'], 'U'); %load spatial components
load([fPath 'interpVc.mat'], 'Vc', 'frames'); %load adjusted temporal components
mask = squeeze(isnan(U(:,:,1))); %true for pixels that are NaN in the first spatial component
allenMask = dorsalMaps.allenMask;
%load behavioral data
bhvFile = dir([fPath '*SpatialDisc*.mat']);
if numel(bhvFile) ~= 1
    % 'bhvFile.name' expands to a comma-separated list, so zero or several
    % matches would otherwise produce a confusing file-not-found error.
    error('tutorial_linearModel:bhvFile', ...
        'Expected exactly one ''*SpatialDisc*.mat'' behavior file in %s but found %d.', fPath, numel(bhvFile));
end
load([fPath bhvFile.name]); %loads every variable stored in the behavior file into the workspace
%% assign some basic options for the model
% There are three event types when building the design matrix.
% Event type 1 will span the rest of the current trial. Type 2 will span
% frames according to sPostTime. Type 3 will span frames before the event
% according to mPreTime and frames after the event according to mPostTime.
% All window lengths are given as a multiple of opts.frameRate and rounded
% up to whole frames (i.e. 6, 0.5 and 2 seconds if frameRate is in Hz).
opts.sPostTime = ceil(6 * opts.frameRate); % follow stim events for sPostTime frames (used for eventType 2)
opts.mPreTime = ceil(0.5 * opts.frameRate); % precede motor events by mPreTime frames to capture preparatory activity (used for eventType 3)
opts.mPostTime = ceil(2 * opts.frameRate); % follow motor events for mPostTime frames (used for eventType 3)
opts.framesPerTrial = frames; % nr. of frames per trial ('frames' was loaded from interpVc.mat)
opts.folds = 10; %nr of folds for cross-validation
%% get some events
% Load the design matrix of the paper to isolate example events and video
% data (refer to the code 'delayDec_RegressModel' to see how this was
% generated in the paper).
load([fPath 'orgRegData.mat'], 'fullR', 'recLabels', 'recIdx', 'idx');
recIdx(idx) = []; %reject non-used regressors
vidR = fullR(:, (end-399):end); %last 400 PCs are video components

% task events
% For each task variable, use the first column of fullR that belongs to it:
%   time       - happens every first frame in every trial
%   lVisStim   - event regressor for the left visual stimulus
%   rVisStim   - event regressor for the right visual stimulus
%   Choice     - true when the animal responded on the left
%   prevReward - true when the previous trial was rewarded
taskLabels = {'time' 'lVisStim' 'rVisStim' 'Choice' 'prevReward'}; %some task variables
taskEventType = [1 2 2 1 1]; %different type of events.
for x = 1 : length(taskLabels)
    labelPos = find(ismember(recLabels, taskLabels(x))); %position of this variable in recLabels
    firstCol = find(recIdx == labelPos, 1); %first regressor column of this variable
    taskEvents(:,x) = fullR(:, firstCol);
end

% movement events
moveLabels = {'lGrab' 'rGrab' 'lLick' 'rLick' 'nose' 'whisk'}; %some movement variables
moveEventType = [3 3 3 3 3 3]; %different type of events. these are all peri-event variables.
for x = 1 : length(moveLabels)
    labelPos = find(ismember(recLabels, moveLabels(x))); %position of this variable in recLabels
    firstCol = find(recIdx == labelPos, 1); %first regressor column of this variable
    % NOTE(review): the fixed offset of 15 columns presumably skips the
    % pre-event regressors of the original design matrix so that the column
    % aligned to the movement itself is used - confirm against delayDec_RegressModel.
    moveEvents(:,x) = fullR(:, firstCol + 15); %find movement regressor.
end
clear fullR labelPos firstCol %clear old design matrix and loop temporaries
% make design matrix
% Task and movement events are expanded separately and then joined with the
% video regressors into a single matrix. regIdx holds, for every column of
% fullR, the position of its variable in regLabels; all video columns share
% one index that follows the last movement variable.
[taskR, taskIdx] = makeDesignMatrix(taskEvents, taskEventType, opts); %design matrix for task variables
[moveR, moveIdx] = makeDesignMatrix(moveEvents, moveEventType, opts); %design matrix for movement variables
fullR = cat(2, taskR, moveR, vidR); %new, single design matrix
moveLabels{end+1} = 'video'; %video counts as a movement variable
regLabels = cat(2, taskLabels, moveLabels);
regIdx = cat(1, taskIdx, ...
    moveIdx + max(taskIdx), ...
    repmat(max(moveIdx)+max(taskIdx)+1, size(vidR,2), 1)); %regressor index
%% run QR and check for rank-deficiency. This will show whether a given regressor is highly collinear with other regressors in the design matrix.
% The resulting plot ranges from 0 to 1 for each regressor, with 1 being
% fully orthogonal to all preceding regressors in the matrix and 0 being
% fully redundant. Having fully redundant regressors in the matrix will
% break the model, so in this example those regressors are removed. In
% practice, you should understand where the redundancy is coming from and
% change your model design to avoid it in the first place!
rejIdx = false(1,size(fullR,2)); %marks rejected columns, relative to fullR before any removal
[~, fullQRR] = qr(bsxfun(@rdivide,fullR,sqrt(sum(fullR.^2))),0); %orthogonalize normalized design matrix
figure; plot(abs(diag(fullQRR)),'linewidth',2); ylim([0 1.1]); title('Regressor orthogonality'); drawnow; %this shows how orthogonal individual regressors are to the rest of the matrix
axis square; ylabel('Norm. vector angle'); xlabel('Regressors');
temp = ~(abs(diag(fullQRR)) > max(size(fullR)) * eps(fullQRR(1))); %regressors whose diagonal entry of R is numerically zero
if any(temp) %design matrix is not full rank
    fprintf('Design matrix is rank-defficient. Removing %d/%d additional regressors.\n', sum(temp), sum(~rejIdx));
    rejIdx(~rejIdx) = temp; %reject regressors that cause rank-deficient matrix
    % Actually drop the rejected columns. Without this step rejIdx is only
    % computed and the model is still fit on the rank-deficient matrix.
    % regIdx is trimmed as well so it keeps one entry per column of fullR.
    fullR(:, rejIdx) = [];
    regIdx(rejIdx) = [];
end
% save([fPath filesep 'regData.mat'], 'fullR', 'regIdx', 'regLabels','fullQRR','-v7.3'); %save some model variables
%% fit model to imaging data
% Fit the ridge regression on all frames at once (no cross-validation here).
% NOTE(review): the third argument of ridgeMML presumably controls
% re-centering of the data - confirm in ridgeMML.
[ridgeVals, dimBeta] = ridgeMML(Vc', fullR, true); %get ridge penalties and beta weights.
% save([fPath 'dimBeta.mat'], 'dimBeta', 'ridgeVals'); %save beta kernels

%reconstruct imaging data and compute R^2
Vm = (fullR * dimBeta)'; %model prediction, transposed to the orientation of Vc
% squared correlation between data and model -> recreate full frame -> align to allen atlas
corrMat = alignAllenTransIm(arrayShrink(modelCorr(Vc,Vm,U) .^2, mask, 'split'), opts.transParams);
corrMat = corrMat(:, 1:size(allenMask,2)); %crop columns to the width of the allen mask
%% check beta kernels
% Project the beta weights of one model variable back into pixel space and
% show the result as a movie with one frame per regressor of that variable.
% select variable of interest. Must be included in 'regLabels'.
cVar = 'rVisStim';
% cVar = 'rGrab';
% cVar = 'whisk';

% find beta weights for current variable
cIdx = regIdx == find(ismember(regLabels,cVar)); %columns of fullR that belong to cVar
% U is flattened to pixels x components for the multiplication only; the
% variable itself keeps its original 3-D shape.
cBeta = reshape(U, [], size(Vc,1)) * dimBeta(cIdx, :)';
cBeta = reshape(cBeta, size(mask,1), size(mask,2), []); %back to image x regressor
compareMovie(cBeta)
%% run cross-validation
% Three models are cross-validated with opts.folds folds: all regressors,
% task regressors only and movement regressors only. For each model the
% results are saved to fPath and a map of explained variance is created by
% squaring the output of modelCorr, recreating the full frame, aligning it
% to the allen atlas and cropping it to the width of allenMask.
% Note that the task and movement runs overwrite taskR/taskIdx/taskLabels
% and moveR/moveIdx/moveLabels with the outputs of crossValModel.

%full model - this will take a moment
[Vfull, fullBeta, ~, fullIdx, fullRidge, fullLabels] = crossValModel(fullR, Vc, regLabels, regIdx, regLabels, opts.folds);
save([fPath 'cvFull.mat'], 'Vfull', 'fullBeta', 'fullR', 'fullIdx', 'fullRidge', 'fullLabels'); %save some results
fullMat = alignAllenTransIm(arrayShrink(modelCorr(Vc,Vfull,U) .^2, mask, 'split'), opts.transParams);
fullMat = fullMat(:, 1:size(allenMask,2));

%task model alone - this will take a moment
[Vtask, taskBeta, taskR, taskIdx, taskRidge, taskLabels] = crossValModel(fullR, Vc, taskLabels, regIdx, regLabels, opts.folds);
save([fPath 'cvTask.mat'], 'Vtask', 'taskBeta', 'taskR', 'taskIdx', 'taskRidge', 'taskLabels'); %save some results
taskMat = alignAllenTransIm(arrayShrink(modelCorr(Vc,Vtask,U) .^2, mask, 'split'), opts.transParams);
taskMat = taskMat(:, 1:size(allenMask,2));

%movement model alone - this will take a moment
[Vmove, moveBeta, moveR, moveIdx, moveRidge, moveLabels] = crossValModel(fullR, Vc, moveLabels, regIdx, regLabels, opts.folds);
save([fPath 'cvMove.mat'], 'Vmove', 'moveBeta', 'moveR', 'moveIdx', 'moveRidge', 'moveLabels'); %save some results
moveMat = alignAllenTransIm(arrayShrink(modelCorr(Vc,Vmove,U) .^2, mask, 'split'), opts.transParams);
moveMat = moveMat(:, 1:size(allenMask,2));
%% show R^2 results
% Every panel is drawn the same way: show the map with a fixed color range,
% apply the 'inferno' colormap to its axes and make NaN pixels transparent.
% NOTE(review): 'inferno' is presumably a colormap function shipped with
% the repo - it has to be on the Matlab path.

%cross-validated R^2
panelMaps = {fullMat, taskMat, moveMat};
panelTitles = {'cVR^2 - Full model', 'cVR^2 - Task model', 'cVR^2 - Movement model'};
figure;
for iPanel = 1 : numel(panelMaps)
    subplot(1,3,iPanel);
    mapImg = imshow(panelMaps{iPanel},[0 0.75]);
    colormap(mapImg.Parent,'inferno'); axis image; title(panelTitles{iPanel});
    set(mapImg,'AlphaData',~isnan(mapImg.CData)); %make NaNs transparent.
end

%unique R^2
% Unique contribution of one group = full model minus the model that only
% contains the other group.
panelMaps = {fullMat - moveMat, fullMat - taskMat};
panelTitles = {'deltaR^2 - Task model', 'deltaR^2 - Movement model'};
figure;
for iPanel = 1 : numel(panelMaps)
    subplot(1,2,iPanel);
    mapImg = imshow(panelMaps{iPanel},[0 0.4]);
    colormap(mapImg.Parent,'inferno'); axis image; title(panelTitles{iPanel});
    set(mapImg,'AlphaData',~isnan(mapImg.CData)); %make NaNs transparent.
end
clear panelMaps panelTitles iPanel %remove plotting temporaries from the workspace