Поиск ближайших соседей и их реализация

Я работаю над классификацией простых данных с использованием KNN с евклидовым расстоянием. Я видел пример того, что я хотел бы сделать, это делается с помощью функции MATLAB knnsearch, как показано ниже:

load fisheriris 
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)

Вышеприведенный код принимает новую точку, т.е. [5 1.45], и находит 10 ближайших значений для новой точки. Может ли кто-нибудь показать мне алгоритм MATLAB с подробным объяснением того, что делает функция knnsearch? Есть ли другой способ сделать это?

Ответ 1

Основой алгоритма K-Nearest Neighbor (KNN) является то, что у вас есть матрица данных, состоящая из N строк и M столбцов, где N - это количество точек данных, которые у нас есть, а M - размерность каждой точки данных. Например, если мы поместили декартовы координаты в матрицу данных, это обычно матрица N x 2 или N x 3. Используя эту матрицу данных, вы предоставляете точку запроса и ищете ближайшие k точек в этой матрице данных, которые являются ближайшими к этой точке запроса.

Мы обычно используем евклидово расстояние между запросом и остальными точками в вашей матрице данных, чтобы вычислить наши расстояния. Тем не менее, другие расстояния, такие как L1 или City-Block/Manhattan, также используются. После этой операции у вас будет N евклидовых или манхэттенских расстояний, которые символизируют расстояния между запросами с каждой соответствующей точкой в наборе данных. Найдя их, вы просто ищете k ближайших точек к запросу, сортируя расстояния в порядке возрастания и извлекая те k точек, которые имеют наименьшее расстояние между вашим набором данных и запросом.

Предположим, что ваша матрица данных была сохранена в x, а newpoint - это точка выборки, в которой она имеет M столбцов (т.е. 1 x M), это общая процедура, которой вы должны следовать в форме точек:

  1. Найти евклидово или манхэттенское расстояние между newpoint и каждой точкой в x.
  2. Отсортируйте эти расстояния в порядке возрастания.
  3. Вернуть k точек данных в x которые находятся ближе всего к newpoint.

Давайте делать каждый шаг медленно.


Шаг 1

Один из способов сделать это может быть в цикле for например:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end

Если бы вы хотели реализовать расстояние до Манхэттена, это было бы просто:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sum(abs(x(idx,:) - newpoint));
end

dists будет N элементным вектором, который содержит расстояния между каждой точкой данных в x и newpoint. Мы делаем поэлементное вычитание между newpoint и точкой данных в x, newpoint в квадрат различия, затем sum их все вместе. Эта сумма затем имеет квадратные корни, что завершает евклидово расстояние. Для Манхэттенского расстояния вы должны выполнить вычитание элемент за элементом, взять абсолютные значения, а затем сложить все компоненты вместе. Это, пожалуй, самая простая из реализаций для понимания, но, возможно, она может быть самой неэффективной... особенно для наборов данных большего размера и большей размерности ваших данных.

Другим возможным решением было бы реплицировать newpoint и сделать эту матрицу того же размера, что и x, затем выполнить поэлементное вычитание этой матрицы, затем суммировать все столбцы для каждой строки и выполнить квадратный корень. Поэтому мы можем сделать что-то вроде этого:

N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));

Для Манхэттенского расстояния вы должны сделать:

N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);

repmat берет матрицу или вектор и повторяет их определенное количество раз в заданном направлении. В нашем случае мы хотим взять наш вектор newpoint и сложить это N раз друг над другом, чтобы создать матрицу N x M, где каждая строка имеет длину M элементов. Мы вычитаем эти две матрицы вместе, затем возводим в квадрат каждую компоненту. Как только мы это сделаем, мы sum все столбцы для каждой строки и, наконец, берем квадратный корень из всех результатов. Для Манхэттенского расстояния мы делаем вычитание, берем абсолютное значение и затем суммируем.

Однако, по моему мнению, наиболее эффективный способ сделать это - использовать bsxfun. По сути, это делает репликацию, о которой мы говорили, с помощью одного вызова функции. Поэтому код будет просто так:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));

Для меня это выглядит намного чище и по существу. Для Манхэттенского расстояния вы должны сделать:

dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);

Шаг 2

Теперь, когда у нас есть расстояния, мы просто их сортируем. Мы можем использовать sort для сортировки наших расстояний:

[d,ind] = sort(dists);

d будет содержать расстояния, отсортированные в порядке возрастания, в то время как ind говорит вам для каждого значения в несортированном массиве, где оно появляется в отсортированном результате. Нам нужно использовать ind, извлечь первые k элементов этого вектора, а затем использовать ind для индексации в нашей матрице данных x чтобы вернуть те точки, которые были ближе всего к newpoint.

Шаг 3

Последний шаг - вернуть те k точек данных, которые находятся ближе всего к newpoint. Мы можем сделать это очень просто:

ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

ind_closest должен содержать индексы в исходной матрице данных x которые находятся ближе всего к newpoint. В частности, ind_closest содержит данные о том, какие строки необходимо ind_closest в x чтобы получить самые близкие точки к newpoint. x_closest будет содержать эти фактические точки данных.


Для вашего удобства копирования и вставки код выглядит следующим образом:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

Пробежавшись по вашему примеру, давайте посмотрим на наш код в действии:

load fisheriris 
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;

%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

ind_closest и x_closest, мы получаем следующее:

>> ind_closest

ind_closest =

   120
    53
    73
   134
    84
    77
    78
    51
    64
    87

>> x_closest

x_closest =

    5.0000    1.5000
    4.9000    1.5000
    4.9000    1.5000
    5.1000    1.5000
    5.1000    1.6000
    4.8000    1.4000
    5.0000    1.7000
    4.7000    1.4000
    4.7000    1.4000
    4.7000    1.5000

Если вы запустили knnsearch, вы увидите, что ваша переменная n совпадает с ind_closest. Однако переменная d возвращает расстояния от newpoint точки до каждой точки x, а не сами точки данных. Если вы хотите реальные расстояния, просто сделайте следующее после кода, который я написал:

dist_sorted = d(1:k);

Обратите внимание, что ответ выше использует только одну точку запроса в пакете из N примеров. Очень часто KNN используется одновременно на нескольких примерах. Предположим, что у нас есть Q точек запроса, которые мы хотим проверить в KNN. В результате получается матрица kx M x Q где для каждого примера или каждого среза мы возвращаем k ближайших точек с размерностью M В качестве альтернативы мы можем вернуть идентификаторы k ближайших точек, что приведет к матрице Q xk. Позвольте вычислить оба.

Наивным способом сделать это было бы применить вышеуказанный код в цикле и цикле над каждым примером.

Примерно так будет работать, когда мы выделяем матрицу Q xk и применяем bsxfun основанный на bsxfun чтобы установить каждую строку выходной матрицы для k ближайших точек в наборе данных, где мы будем использовать набор данных Fisher Iris, как и раньше. Мы также сохраним ту же размерность, что и в предыдущем примере, и я буду использовать четыре примера, поэтому Q = 4 и M = 2:

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);

%// Loop through each point and do logic as seen above:
for ii = 1 : Q
    %// Get the point
    newpoint = newpoints(ii, :);

    %// Use Euclidean
    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    [d,ind] = sort(dists);

    %// New - Output the IDs of the match as well as the points themselves
    ind_closest(ii, :) = ind(1 : k).';
    x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end

Хотя это очень хорошо, мы можем сделать еще лучше. Существует способ эффективно вычислить квадрат евклидова расстояния между двумя наборами векторов. Я оставлю это как упражнение, если вы хотите сделать это с Манхэттеном. Обращаясь к этому блогу, учитывая, что A - это матрица Q1 x M где каждая строка - это точка размерности M с точками Q1 а B - матрица Q2 x M где каждая строка также является точкой размерности M с точками Q2, мы можем эффективно вычислить матрицу расстояний D(i, j) где элемент в строке i и столбце j обозначает расстояние между строкой i в A и строкой j в B используя следующую матричную формулировку:

nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation

Поэтому, если мы позволим A быть матрицей точек запроса, а B - набором данных, состоящим из ваших исходных данных, мы можем определить k ближайших точек, отсортировав каждую строку по отдельности и определив k местоположений каждой строки, которые были наименьшими. Мы также можем дополнительно использовать это, чтобы получить фактические точки сами.

Следовательно:

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);

nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation 

%// Sort the distances 
[d, ind] = sort(D, 2);

%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);

%// Also get the nearest points
x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]);

Мы видим, что мы использовали логику для вычисления матрицы расстояний, но некоторые переменные были изменены в соответствии с примером. Мы также сортируем каждую строку независимо друг от друга, используя две входные версии sort поэтому ind будет содержать идентификаторы для каждой строки, а d будет содержать соответствующие расстояния. Затем мы выясняем, какие индексы являются ближайшими к каждой точке запроса, просто обрезая эту матрицу до k столбцов. Затем мы используем permute и reshape чтобы определить, каковы связанные ближайшие точки. Сначала мы используем все ближайшие индексы и создаем точечную матрицу, которая укладывает все идентификаторы друг на друга, чтобы получить матрицу Q * kx M Использование reshape и permute позволяет нам создать нашу трехмерную матрицу, чтобы она стала матрицей kx M x Q как мы указали. Если вы хотите сами получить фактические расстояния, мы можем индексировать в d и получить то, что нам нужно. Чтобы сделать это, вам нужно будет использовать sub2ind для получения линейных индексов, чтобы мы могли индексировать в d одним выстрелом. Значения ind_closest уже дают нам, к каким столбцам мы должны получить доступ. Строки, к которым нам нужно получить доступ - это просто 1, k раз, 2, k раз и т.д. До Q k для количества точек, которые мы хотели вернуть:

row_indices = repmat((1:Q).', 1, k);
linear_ind = sub2ind(size(d), row_indices, ind_closest);
dist_sorted = D(linear_ind);

Когда мы запускаем приведенный выше код для указанных выше точек запроса, мы получаем следующие индексы, точки и расстояния:

>> ind_closest

ind_closest =

   120   134    53    73    84    77    78    51    64    87
   123   119   118   106   132   108   131   136   126   110
   107    62    86   122    71   127   139   115    60    52
    99    65    58    94    60    61    80    44    54    72

>> x_closest

x_closest(:,:,1) =

    5.0000    1.5000
    6.7000    2.0000
    4.5000    1.7000
    3.0000    1.1000
    5.1000    1.5000
    6.9000    2.3000
    4.2000    1.5000
    3.6000    1.3000
    4.9000    1.5000
    6.7000    2.2000


x_closest(:,:,2) =

    4.5000    1.6000
    3.3000    1.0000
    4.9000    1.5000
    6.6000    2.1000
    4.9000    2.0000
    3.3000    1.0000
    5.1000    1.6000
    6.4000    2.0000
    4.8000    1.8000
    3.9000    1.4000


x_closest(:,:,3) =

    4.8000    1.4000
    6.3000    1.8000
    4.8000    1.8000
    3.5000    1.0000
    5.0000    1.7000
    6.1000    1.9000
    4.8000    1.8000
    3.5000    1.0000
    4.7000    1.4000
    6.1000    2.3000


x_closest(:,:,4) =

    5.1000    2.4000
    1.6000    0.6000
    4.7000    1.4000
    6.0000    1.8000
    3.9000    1.4000
    4.0000    1.3000
    4.7000    1.5000
    6.1000    2.5000
    4.5000    1.5000
    4.0000    1.3000

>> dist_sorted

dist_sorted =

    0.0500    0.1118    0.1118    0.1118    0.1803    0.2062    0.2500    0.3041    0.3041    0.3041
    0.3000    0.3162    0.3606    0.4123    0.6000    0.7280    0.9055    0.9487    1.0198    1.0296
    0.9434    1.0198    1.0296    1.0296    1.0630    1.0630    1.0630    1.1045    1.1045    1.1180
    2.6000    2.7203    2.8178    2.8178    2.8320    2.9155    2.9155    2.9275    2.9732    2.9732

Чтобы сравнить это с knnsearch, вместо этого вы должны указать матрицу точек для второго параметра, где каждая строка является точкой запроса, и вы увидите, что индексы и отсортированные расстояния совпадают между этой реализацией и knnsearch.


Надеюсь, это поможет вам. Удачи!