Matlab/octave - Обобщенное матричное умножение

Я хотел бы сделать функцию для обобщения умножения матрицы. В принципе, он должен иметь возможность выполнять стандартное умножение матрицы, но он должен позволять изменять два бинарных оператора product/sum любой другой функцией.

Цель должна быть максимально эффективной, как с точки зрения процессора, так и с точки зрения памяти. Конечно, он всегда будет менее эффективным, чем A * B, но гибкость операторов здесь.

Вот несколько команд, которые я мог придумать после прочтения различных интересных темы:

A = randi(10, 2, 3);
B = randi(10, 3, 4);

% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))

% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)

% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)

% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = C2 = squeeze(C(1,:,:)); % sum and turn into mxd

Проблема с методами 1-3 состоит в том, что они будут генерировать n матриц, прежде чем сворачивать их с помощью sum(). 4 лучше, потому что он выполняет сумму() внутри bsxfun, но bsxfun все еще генерирует n матриц (за исключением того, что они в основном пусты, содержащие только вектор значений non-zeros, являющихся суммами, остальное заполняется 0, чтобы соответствовать требования к размерам).

Что бы я хотел, это что-то вроде 4-го метода, но без бесполезной памяти 0.

Любая идея?

Ответ 1

Вот немного более отполированная версия решения, которое вы опубликовали, с небольшими улучшениями.

Мы проверяем, есть ли у нас больше строк, чем столбцы или наоборот, и затем умножаем соответственно, выбирая либо умножать строки с матрицами или матрицами со столбцами (тем самым делая наименьшее количество итераций цикла).

A*B

Примечание. Это может быть не всегда лучшая стратегия (переход по строкам вместо столбцов), даже если число строк меньше, чем столбцов; тот факт, что массивы MATLAB хранятся в порядковый номер столбца в памяти, делает его более эффективным для среза по столбцам, поскольку элементы хранятся последовательно, В то время как доступ к строкам включает в себя перемещение элементов strides (что не кэширует - думаю пространственная локальность).

Кроме этого, код должен обрабатывать двойной/одиночный, реальный/сложный, полный/разреженный (и ошибки, где это не возможная комбинация). Он также учитывает пустые матрицы и нулевые размеры.

function C = my_mtimes(A, B, outFcn, inFcn)
    % default arguments
    if nargin < 4, inFcn = @times; end
    if nargin < 3, outFcn = @sum; end

    % check valid input
    assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
    assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
    assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
        'Expecting function handles.')

    % preallocate output matrix
    M = size(A,1);
    N = size(B,2);
    if issparse(A)
        args = {'like',A};
    elseif issparse(B)
        args = {'like',B};
    else
        args = {superiorfloat(A,B)};
    end
    C = zeros(M,N, args{:});

    % compute matrix multiplication
    % http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
    if M < N
        % concatenation of products of row vectors with matrices
        % A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
        for m=1:M
            %C(m,:) = A(m,:) * B;
            %C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
            C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
        end
    else
        % concatenation of products of matrices with column vectors
        % A*B = [A*b_1 , A*b_2 , ... , A*b_n]
        for n=1:N
            %C(:,n) = A * B(:,n);
            %C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
            C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
        end
    end
end

Сравнение

Функция, без сомнения, медленная повсюду, но для больших размеров она на порядок хуже, чем встроенное матричное умножение:

        (tic/toc times in seconds)
      (tested in R2014a on Windows 8)

    size      mtimes       my_mtimes 
    ____    __________     _________
     400     0.0026398       0.20282
     600      0.012039       0.68471
     800      0.014571        1.6922
    1000      0.026645        3.5107
    2000       0.20204         28.76
    4000        1.5578        221.51

mtimes_vs_mymtimes

Вот тестовый код:

sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
    n = sz(i); disp(n)
    A = rand(n,n);
    B = rand(n,n);

    tic
    C = A*B;
    t(i,1) = toc;
    tic
    D = my_mtimes(A,B);
    t(i,2) = toc;

    assert(norm(C-D) < 1e-6)
    clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight

Дополнительные

Для полноты ниже приведены еще два наивных способа реализации обобщенного умножения матрицы (если вы хотите сравнить производительность, замените последнюю часть функции my_mtimes любым из них). Я даже не собираюсь публиковать свои прошедшие времена:)

C = zeros(M,N, args{:});
for m=1:M
    for n=1:N
        %C(m,n) = A(m,:) * B(:,n);
        %C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
        C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
    end
end

И еще один способ (с тройным циклом):

C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
    for n=1:N
        for p=1:P
            %C(m,n) = C(m,n) + A(m,p)*B(p,n);
            %C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
            C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
        end
    end
end

Что делать дальше?

Если вы хотите выжать больше производительности, вам нужно перейти в MEX файл C/С++, чтобы сократить накладные расходы на интерпретируемый код MATLAB. Вы можете использовать оптимизированные процедуры BLAS/LAPACK, вызвав их из MEX файлов (см. вторую часть этого сообщения для примера). MATLAB поставляется с библиотекой Intel MKL, которую, честно говоря, вы не можете победить, когда речь заходит о вычислениях линейной алгебры на процессорах Intel.

Другие уже упомянули пару заявок на Файловый обмен, которые реализуют универсальные матричные процедуры как MEX файлы (см. @natan ответ). Это особенно эффективно, если вы связываете их с оптимизированной библиотекой BLAS.

Ответ 2

Почему бы просто не использовать bsxfun возможность принимать произвольную функцию?

C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1);

Здесь

  • f - внешняя функция (соответствующая сумме в случае матричного умножения). Он должен принять трехмерный массив произвольного размера m x n x p и работать по его столбцам, чтобы вернуть массив 1 x m x p.
  • g - это внутренняя функция (соответствующая произведению в случае матричного умножения). В соответствии с bsxfun он должен принимать в качестве входных данных либо два вектора столбца того же размера, либо один вектор-столбец и один скаляр, и возвращать в качестве вывода вектор-столбец того же размера, что и вход (-ы).

Это работает в Matlab. Я не тестировал в Octave.


Пример 1: Матричное умножение:

>> f = @sum;   %// outer function: sum
>> g = @times; %// inner function: product
>> A = [1 2 3; 4 5 6];
>> B = [10 11; -12 -13; 14 15];
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
    28    30
    64    69

Check:

>> A*B
ans =
    28    30
    64    69

Пример 2. Рассмотрим две приведенные выше матрицы с

>> f = @(x,y) sum(abs(x));     %// outer function: sum of absolute values
>> g = @(x,y) max(x./y, y./x); %// inner function: "symmetric" ratio
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
   14.8333   16.1538
    5.2500    5.6346

Проверить: вручную вычислить C(1,2):

>> sum(abs( max( (A(1,:))./(B(:,2)).', (B(:,2)).'./(A(1,:)) ) ))
ans =
   16.1538

Ответ 3

Без погружения в детали есть такие инструменты, как mtimesx и MMX, которые являются быстрыми операциями общей матрицы и скалярных операций общего назначения. Вы можете изучить их код и адаптировать их к вашим потребностям. Скорее всего, это будет быстрее, чем matlab bsxfun.

Ответ 4

После изучения нескольких функций обработки, таких как bsxfun, кажется, что невозможно использовать прямое матричное умножение с помощью этих (то, что я подразумеваю под прямым, состоит в том, что временные продукты не хранятся в памяти, а суммируются как ASAP, а затем другие потому что они имеют выход фиксированного размера (либо тот же, что и вход, либо с расширением singles bsxfun, декартовым произведением размеров двух входов). Однако можно немного обмануть Octave (что не работает с MatLab, который проверяет размеры вывода):

C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', sparse(1, size(A,1)))
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', zeros(1, size(A,1), 2))(:,:,2)

Однако не используйте их, потому что полученные значения не являются надежными (Octave может отменить или даже удалить их и вернуть 0!).

Итак, теперь я просто реализую полу-векторную версию, здесь моя функция:

function C = genmtimes(A, B, outop, inop)
% C = genmtimes(A, B, inop, outop)
% Generalized matrix multiplication between A and B. By default, standard sum-of-products matrix multiplication is operated, but you can change the two operators (inop being the element-wise product and outop the sum).
% Speed note: about 100-200x slower than A*A' and about 3x slower when A is sparse, so use this function only if you want to use a different set of inop/outop than the standard matrix multiplication.

if ~exist('inop', 'var')
    inop = @times;
end

if ~exist('outop', 'var')
    outop = @sum;
end

[n, m] = size(A);
[m2, o] = size(B);

if m2 ~= m
    error('nonconformant arguments (op1 is %ix%i, op2 is %ix%i)\n', n, m, m2, o);
end


C = [];
if issparse(A) || issparse(B)
    C = sparse(o,n);
else
    C = zeros(o,n);
end

A = A';
for i=1:n
    C(:,i) = outop(bsxfun(inop, A(:,i), B))';
end
C = C';

end

Протестировано как с разреженными, так и с нормальными матрицами: разрыв производительности намного меньше с разреженными матрицами (в 3 раза медленнее), чем с нормальными матрицами (~ 100x медленнее).

Я думаю, что это медленнее, чем реализация bsxfun, но, по крайней мере, он не переполняет память:

A = randi(10, 1000);
C = genmtimes(A, A');

Если кто-нибудь лучше предложить, я все равно ищу лучшую альтернативу.