Оптимизируйте функцию списка, которая создает слишком много мусора (не переполнение стека)

У меня есть функция Haskell, которая вызывает более 50% всех распределений моей программы, в результате чего 60% времени моего запуска будет принято GC. Я запускаю небольшой стек (-K10K), поэтому нет, но могу ли я сделать эту функцию быстрее, с меньшим распределением?

Цель здесь - вычислить произведение матрицы на вектор. Я не могу использовать hmatrix, потому что это часть большей функции с помощью ad Автоматическое дифференцирование, поэтому мне нужно использовать списки Num. Во время выполнения я полагаю, что использование модуля Numeric.AD означает, что мои типы должны быть Scalar Double.

listMProd :: (Num a) => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _  s = [s]
    go ls [] s = s : go ls vdt 0
    go (y:ys) (x:xs) ix = go ys xs (y*x+ix)

В основном мы прокручиваем матрицу, умножая и добавляя аккумулятор, пока не дойдем до конца вектора, сохраняя результат, а затем продолжаем снова перезапускать вектор. У меня есть тест quickcheck, подтверждающий, что я получаю тот же результат, что и матричное/векторное произведение в hmatrix.

Я пробовал с foldl, foldr и т.д. Ничего из того, что я пробовал, делает функцию быстрее (и некоторые вещи, такие как foldr, вызывают утечку пространства).

Запуск с профилированием говорит мне, помимо того, что эта функция используется там, где большая часть времени и распределения тратится, что есть грузы Cells, Cells является типом данных из пакета ad.

Простой тест для запуска:

import Numeric.AD

main = do
    let m :: [Double] = replicate 400 0.2
        v :: [Double] = replicate 4 0.1
        mycost v m = sum $ listMProd m v 
        mygrads = gradientDescent (mycost (map auto v)) (map auto m)
    print $ mygrads !! 1000

Это на моей машине говорит, что GC занят 47% времени.

Любые идеи?

Ответ 1

Очень простая оптимизация заключается в том, чтобы сделать функцию go строгой по ее параметру аккумулятора, потому что она мала, может быть распакована, если a является примитивной и всегда нуждается в полной оценке:

{-# LANGUAGE BangPatterns #-}
listMProd :: (Num a) => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _  !s = [s]
    go ls [] !s = s : go ls vdt 0
    go (y:ys) (x:xs) !ix = go ys xs (y*x+ix)

На моей машине он дает 3-4 раза ускорение (скомпилировано с -O2).

С другой стороны, промежуточные списки не должны быть строгими, чтобы они могли быть сплавлены.