Сколько Haskell/GHC memoize?

Я написал следующий код для отображения треугольника Паскаля:

import Control.Monad
import Data.List

pascalRow :: Integer -> [Integer]
pascalRow 0 = [1]
pascalRow n = map sumParents pairs
  where previousRow = 0:(pascalRow $ n - 1)++[0]
        pairs = zip previousRow (tail previousRow)
        sumParents (a, b) = a + b

-- Read an integer from stdin, and print the Pascal triangle of that height.
main = do
  n <- readLn
  forM_ [0..n-1] printRow
    where printRow k = putStrLn $ intercalate " " $ map show $ pascalRow k

Игнорируя уродство ++ [0] 1 мне интересно, насколько эффективен этот код. Мне кажется, что есть две возможности.

При вычислении pascalRow n после вычисления всех map pascalRow [1..n-1]:

  • GHC memoizes предыдущие значения, поэтому previousRow вычисляется в постоянное время. (Или возможно O (n) для операции добавления.) Поэтому вычисление pascalRow n занимает только время O (n), а построение всех строк до n (т.е. map pascalRow [1..n]) должно принимать O (n 2).
  • GHC забывает предыдущие значения, поэтому для вычисления previousRow необходимо выполнить весь путь до конца. Кажется, что это должно быть O (n 3) (потому что оно & Sigma; я = 0 → n O (n 2)).

В этом случае, и как я могу улучшить свою реализацию?


1 хотя советы здесь также будут оценены!

Ответ 1

Вы мнимаете функцию, связывая ее с структурой данных, которая "запоминает" прошлые приложения. Ghc не будет помнить произвольные приложения прошлой функции, но он помнит столько, сколько он выработал из структуры, над которой он все еще работает. В этом случае функция pascalRow на самом деле не нужна: мы просто описываем бесконечный треугольник паскаля и печатаем столько, сколько необходимо.

import Control.Monad
import Data.List

pstep :: [Integer] -> [Integer]
pstep xs = zipWith (+) (0:xs) (xs ++ [0])

-- the infinite pascal triangle
pascal = iterate pstep [1] 
pascalRow n = pascal !! n  -- not needed, but fine

-- Read an integer from stdin, 
-- and print that much of the infinite Pascal triangle.
main = do
      n <- readLn
      mapM_ printRow (take n pascal)
  where printRow xs = putStrLn $ intercalate " " $ map show xs