-
Notifications
You must be signed in to change notification settings - Fork 0
/
kNearestNeighbors.m
43 lines (40 loc) · 1.59 KB
/
kNearestNeighbors.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
function [iknn, dist] = kNearestNeighbors(XTr, XTe, k, M)
%%=========================================================================
% Find k nearest neighbors for each testing instance in the training set
% INPUT:
% XTr : training examples by columns
% XTe : testing examples by columns
% k : number of nearest neighbors (reduced to the number of training
%     examples, with a warning, when it exceeds that number)
% M : (optional) matrix defining the distance
%     d(a,b) = (a-b)'*M*(a-b). When omitted or empty, the identity is
%     used, i.e. the plain squared Euclidean distance.
%     NOTE(review): the expansion used below assumes M is symmetric
%     -- confirm against callers.
% OUTPUT:
% iknn: index of k nearest neighbors by columns for each instance
% dist: squared distance (under M) of k nearest neighbors by columns,
%       sorted in ascending order; tiny negative values caused by
%       floating-point cancellation are set to zero
%==========================================================================
% author: Bac Nguyen Cong
% email : [email protected]
%%=========================================================================
    % An empty M is treated like an omitted one so that callers can pass []
    % as a placeholder.
    if ~exist('M', 'var') || isempty(M)
        M = eye(size(XTr,1));
    end

    nTrains = size(XTr,2);
    nTests = size(XTe,2);
    maxBlocks = 700; % number of testing columns processed per iteration

    % Fail early with a clear message instead of a cryptic matrix-product
    % error (or a silently wrong result when XTr is 1-by-1).
    if nTests > 0 && size(XTe,1) ~= size(XTr,1)
        error('kNearestNeighbors:dimensionMismatch', ...
            'XTr and XTe must have the same number of rows (%d vs %d).', ...
            size(XTr,1), size(XTe,1));
    end

    % check number of nearest neighbors
    if k > nTrains
        k = nTrains;
        warning('k is bigger than training examples'); %#ok<WNTAG>
    end

    iknn = zeros(k, nTests);
    dist = zeros(k, nTests);

    % x'*M*x for every training column, as an nTrains-by-1 vector
    X = sum(XTr.*(M*XTr), 1)';

    % compute by each block to bound the size of the distance matrix
    for i = 1:maxBlocks:nTests
        cols = i:min(i+maxBlocks-1, nTests);
        x = XTe(:,cols);
        Mx = M*x;

        % D(r,c) = XTr(:,r)'*M*XTr(:,r) + x(:,c)'*M*x(:,c)
        %          - 2*XTr(:,r)'*M*x(:,c)
        D = bsxfun(@plus, X, bsxfun(@plus, sum(x.*Mx, 1), -2*(XTr'*Mx)));

        % sort in place so that only one extra nTrains-by-block array exists
        [D, order] = sort(D, 1);

        nearest = D(1:k,:);
        % Rounding can make the distance between (near-)identical points
        % slightly negative. Clamp after sorting so that the neighbor order
        % is unaffected; logical indexing leaves NaN entries untouched.
        nearest(nearest < 0) = 0;

        % save result
        iknn(:,cols) = order(1:k,:);
        dist(:,cols) = nearest;
    end
end
%%=======================================================================