Scala Рекурсия против цикла: производительность и время выполнения

Я написал наивный тестовый слой, чтобы измерить производительность трех видов факториальной реализации: на основе петли, не хвост-рекурсивный и хвостовой рекурсивный.

Удивительно для меня худшим исполнителем был цикл ( "пока" ожидалось, что он будет более эффективным, поэтому я предоставил оба), которые стоят почти в два раза больше, чем хвостовая рекурсивная альтернатива.

ANSWER: исправление реализации цикла, исключающее оператор * =, который превосходит худшее с BigInt из-за его внутренних "петель", стал самым быстрым, как ожидалось

Другим поведением "woodoo", которое я испытал, был StackOverflow исключение, которое не было выбрано системно для одного и того же ввода в случай нерегулярной рекурсивной реализации. Я могу обойти StackOverlow постепенно вызывает функцию с большим и большим значения... Я чувствую себя сумасшедшим:) Ответ: JVM требует сходиться во время запуска, тогда поведение является когерентным и систематическим

Это код:

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    var i = 1
    while (i <= n) {
      accumulator = i * accumulator
      i += 1
    }
    accumulator
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(n: Int, acc: Out): Out = n match {
      case _ if n == 1 => acc
      case _ => fac(n-1, n * acc)
    }

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(i: Int, acc: Out): Out = n match {
      case _ if i == n => n * acc
      case _ => fac(i+1, i * acc)
    }

    fac(1, 1)
  }

  def comparePerformance(n: Int) {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = false) =
      showOutput match {
        case true => printf("%s returned %s in %d ms\n", msg, data._2.toString, data._1)
        case false => printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    showOutput ("By for loop", measure(()=>calculateByForLoop(n)))
    showOutput ("By while loop", measure(()=>calculateByWhileLoop(n)))
    showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n)))
    showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n)))
    showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n)))
  }
}

Далее следует какой-то вывод из консоли sbt (до реализации "while" ):

scala> example.Factorial.comparePerformance(10000)
By loop in 3 ns
By non-tail recursion in >>>>> StackOverflow!!!!!… see later!!!
........

scala> example.Factorial.comparePerformance(1000)
By loop in 3 ms
By non-tail recursion in 1 ms
By tail recursion in 4 ms

scala> example.Factorial.comparePerformance(5000)
By loop in 105 ms
By non-tail recursion in 27 ms
By tail recursion in 34 ms

scala> example.Factorial.comparePerformance(10000)
By loop in 236 ms
By non-tail recursion in 106 ms     >>>> Now works!!!
By tail recursion in 127 ms

scala> example.Factorial.comparePerformance(20000)
By loop in 977 ms
By non-tail recursion in 495 ms
By tail recursion in 564 ms

scala> example.Factorial.comparePerformance(30000)
By loop in 2285 ms
By non-tail recursion in 1183 ms
By tail recursion in 1281 ms

Далее следует какой-то вывод из консоли sbt (после реализации "while" ):

scala> example.Factorial.comparePerformance(10000)
By for loop in 252 ms
By while loop in 246 ms
By non-tail recursion in 130 ms
By tail recursion in 136 ns

scala> example.Factorial.comparePerformance(20000)
By for loop in 984 ms
By while loop in 1091 ms
By non-tail recursion in 508 ms
By tail recursion in 560 ms

Далее следует какой-то вывод из консоли sbt (после "восходящей" реализации хвостовой рекурсии) мир возвращается в нормальное состояние:

scala> example.Factorial.comparePerformance(10000)
By for loop in 259 ms
By while loop in 229 ms
By non-tail recursion in 114 ms
By tail recursion in 119 ms
By tail recursion upward in 105 ms

scala> example.Factorial.comparePerformance(20000)
By for loop in 1053 ms
By while loop in 957 ms
By non-tail recursion in 513 ms
By tail recursion in 565 ms
By tail recursion upward in 470 ms

Далее следует какой-то вывод из консоли sbt после фиксации BigInt-умножения в "циклах": мир абсолютно нормальный:

    scala> example.Factorial.comparePerformance(20000)
By for loop in 498 ms
By while loop in 502 ms
By non-tail recursion in 521 ms
By tail recursion in 611 ms
By tail recursion upward in 503 ms

Накладные расходы BigInt и глупое выполнение мной замаскировали ожидаемое поведение.

Спасибо, ребята,

PS: В конце я должен переименовать это сообщение в "Урок Lernt на BigInts"

Ответ 1

Для циклов на самом деле не совсем петли; они предназначены для понимания на диапазоне. Если вам действительно нужен цикл, вам нужно использовать while. (На самом деле, я думаю, что умножение BigInt здесь имеет большой вес, поэтому это не имеет значения. Но вы заметите, что вы умножаете Int s.)

Кроме того, вы смутили себя, используя BigInt. Чем больше ваш BigInt, тем медленнее умножение. Таким образом, ваша не-хвостовая петля подсчитывается, в то время как контур хвоста рекурсии заканчивается, что означает, что последний имеет более большие числа, чтобы размножаться.

Если вы исправите эти две проблемы, вы обнаружите, что здравомыслие восстанавливается: циклы и хвостовая рекурсия имеют одинаковую скорость, причем как регулярная рекурсия, так и for медленнее. (Регулярная рекурсия может быть не медленнее, если оптимизация JVM делает ее эквивалентной)

(Кроме того, исправление, вероятно, связано с тем, что JVM начинает встраивание и может либо сделать сам хвост-рекурсивный вызов, либо развернуть цикл достаточно далеко, чтобы вы больше не переполняли.)

Наконец, вы получаете плохие результаты за и время, потому что вы умножаетесь справа, а не слева, с небольшим числом. Оказывается, Java BigInt умножается быстрее с меньшим числом слева.

Ответ 2

Scala статические методы для factorial(n) (закодированы с помощью scala 2.12.x, java-8):

object Factorial {

  /*
   * For large N, it throws a stack overflow
   */
  def recursive(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      n * recursive(n - 1)
    }
  }

  /*
   * A tail recursive method is compiled to avoid stack overflow
   */
  @scala.annotation.tailrec
  def recursiveTail(n:BigInt, acc:BigInt = 1): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      acc
    } else {
      recursiveTail(n - 1, n * acc)
    }
  }

  /*
   * A while loop
   */
  def loop(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      var acc = 1
      var idx = 1
      while(idx <= n) {
        acc = idx * acc
        idx += 1
      }
      acc
    }
  }

}

Технические характеристики:

class FactorialSpecs extends SpecHelper {

  private val smallInt = 10
  private val largeInt = 10000

  describe("Factorial.recursive") {
    it("return 1 for 0") {
      assert(Factorial.recursive(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursive(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursive(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursive(smallInt) == 3628800)
    }
    it("throws StackOverflow for large inputs") {
      intercept[java.lang.StackOverflowError] {
        Factorial.recursive(Int.MaxValue)
      }
    }
  }

  describe("Factorial.recursiveTail") {
    it("return 1 for 0") {
      assert(Factorial.recursiveTail(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursiveTail(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursiveTail(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursiveTail(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.recursiveTail(largeInt).isInstanceOf[BigInt])
    }
  }

  describe("Factorial.loop") {
    it("return 1 for 0") {
      assert(Factorial.loop(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.loop(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.loop(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.loop(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.loop(largeInt).isInstanceOf[BigInt])
    }
  }
}

Ориентиры:

import org.scalameter.api._

class BenchmarkFactorials extends Bench.OfflineReport {

  val gen: Gen[Int] = Gen.range("N")(1, 1000, 100) // scalastyle:ignore

  performance of "Factorial" in {
    measure method "loop" in {
      using(gen) in {
        n => Factorial.loop(n)
      }
    }
    measure method "recursive" in {
      using(gen) in {
        n => Factorial.recursive(n)
      }
    }
    measure method "recursiveTail" in {
      using(gen) in {
        n => Factorial.recursiveTail(n)
      }
    }
  }

}

Результаты тестов (цикл намного быстрее):

[info] Test group: Factorial.loop
[info] - Factorial.loop.Test-9 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.01 ms, ci = <0.00 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.01 ms, ci = <0.01 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.02 ms, ci = <0.02 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.03 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.03 ms, ci = <0.03 ms, 0.04 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.04 ms, ci = <0.04 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.05 ms, ci = <0.05 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.06 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.07 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursive
[info] - Factorial.recursive.Test-10 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.05 ms, ci = <0.01 ms, 0.09 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.07 ms, ci = <0.00 ms, 0.13 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.09 ms, ci = <0.01 ms, 0.18 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.10 ms, ci = <0.03 ms, 0.17 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.11 ms, ci = <0.08 ms, 0.15 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.13 ms, ci = <0.11 ms, 0.14 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.16 ms, ci = <0.13 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.21 ms, ci = <0.15 ms, 0.27 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursiveTail
[info] - Factorial.recursiveTail.Test-11 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.12 ms, ci = <0.05 ms, 0.20 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.16 ms, ci = <-0.03 ms, 0.34 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.12 ms, ci = <0.09 ms, 0.16 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.17 ms, ci = <0.15 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.23 ms, ci = <0.19 ms, 0.26 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.25 ms, ci = <0.18 ms, 0.32 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.28 ms, ci = <0.21 ms, 0.36 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.32 ms, ci = <0.17 ms, 0.46 ms>, significance = 1.0E-10)

Ответ 3

Я знаю, что все уже ответили на вопрос, но я подумал, что мог бы добавить одну оптимизацию: если вы преобразуете сопоставление с шаблоном в простые операторы if, это может ускорить хвостовую рекурсию.

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var acc: Out = 1
    var i = 1
    while (i <= n) {
      acc = i * acc
      i += 1
    }
    acc
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(n: Int, acc: Out): Out = if (n==1) acc else fac(n-1, n*acc)

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(i: Int, acc: Out): Out = if (i == n) n*acc else fac(i+1, i*acc)

    fac(1, 1)
  }

  def attempt(f: ()=>Unit): Boolean = {
    try {
        f()
        true
    } catch {
        case _: Throwable =>
            println(" <<<<< Failed...")
            false
    }
  }

  def comparePerformance(n: Int) {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = true) =
      showOutput match {
        case true =>
            val res = data._2.toString
            val pref = res.substring(0,5)
            val midd = res.substring((res.length-5)/ 2, (res.length-5)/ 2 + 10)
            val suff = res.substring(res.length-5)
            printf("%s returned %s in %d ms\n", msg, s"$pref...$midd...$suff" , data._1)
        case false => 
            printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    attempt(() => showOutput ("By for loop", measure(()=>calculateByForLoop(n))))
    attempt(() => showOutput ("By while loop", measure(()=>calculateByWhileLoop(n))))
    attempt(() => showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n))))
    attempt(() => showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n))))
    attempt(() => showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n))))
  }
}

Мои результаты:

scala> Factorial.comparePerformance(20000)
By for loop returned 18192...5708616582...00000 in 179 ms
By while loop returned 18192...5708616582...00000 in 159 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 18192...5708616582...00000 in 169 ms
By tail recursion upward returned 18192...5708616582...00000 in 174 ms
By for loop returned 18192...5708616582...00000 in 212 ms
By while loop returned 18192...5708616582...00000 in 156 ms
By non-tail recursion returned 18192...5708616582...00000 in 155 ms
By tail recursion returned 18192...5708616582...00000 in 166 ms
By tail recursion upward returned 18192...5708616582...00000 in 137 ms
scala> Factorial.comparePerformance(200000)
By for loop returned 14202...0169293868...00000 in 17467 ms
By while loop returned 14202...0169293868...00000 in 17303 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 14202...0169293868...00000 in 18477 ms
By tail recursion upward returned 14202...0169293868...00000 in 17188 ms