在Scala中有没有一种通用的方法来记忆?

我想记住这个:

def fib(n: Int) = if(n <= 1) 1 else fib(n-1) + fib(n-2) println(fib(100)) // times out 

所以我写了这个,这令人惊讶的编译和工作(我感到惊讶,因为fib引用本身的声明):

 case class Memo[A,B](f: A => B) extends (A => B) { private val cache = mutable.Map.empty[A, B] def apply(x: A) = cache getOrElseUpdate (x, f(x)) } val fib: Memo[Int, BigInt] = Memo { case 0 => 0 case 1 => 1 case n => fib(n-1) + fib(n-2) } println(fib(100)) // prints 100th fibonacci number instantly 

但是当我试图在一个def声明fib时,我得到一个编译器错误:

 def foo(n: Int) = { val fib: Memo[Int, BigInt] = Memo { case 0 => 0 case 1 => 1 case n => fib(n-1) + fib(n-2) } fib(n) } 

上面未能编译error: forward reference extends over definition of value fib case n => fib(n-1) + fib(n-2)

为什么在一个def中声明val fib失败,但在类/对象范围外?

为了澄清,为什么我可能想在def范围内声明recursionmemoized函数 – 这里是我对子集sum问题的解决scheme:

 /** * Subset sum algorithm - can we achieve sum t using elements from s? * * @param s set of integers * @param t target * @return true iff there exists a subset of s that sums to t */ def subsetSum(s: Seq[Int], t: Int): Boolean = { val max = s.scanLeft(0)((sum, i) => (sum + i) max sum) //max(i) = largest sum achievable from first i elements val min = s.scanLeft(0)((sum, i) => (sum + i) min sum) //min(i) = smallest sum achievable from first i elements val dp: Memo[(Int, Int), Boolean] = Memo { // dp(i,x) = can we achieve x using the first i elements? case (_, 0) => true // 0 can always be achieved using empty set case (0, _) => false // if empty set, non-zero cannot be achieved case (i, x) if min(i) <= x && x <= max(i) => dp(i-1, x - s(i-1)) || dp(i-1, x) // try with/without s(i-1) case _ => false // outside range otherwise } dp(s.length, t) } 

我发现一个更好的方式来记忆使用Scala:

 def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() { override def apply(key: I) = getOrElseUpdate(key, f(key)) } 

现在你可以写斐波那契如下:

 lazy val fib: Int => BigInt = memoize { case 0 => 0 case 1 => 1 case n => fib(n-1) + fib(n-2) } 

这里有一个有多个参数(select函数):

 lazy val c: ((Int, Int)) => BigInt = memoize { case (_, 0) => 1 case (n, r) if r > n/2 => c(n, n - r) case (n, r) => c(n - 1, r - 1) + c(n - 1, r) } 

这里是子集和问题:

 // is there a subset of s which has sum = t def isSubsetSumAchievable(s: Vector[Int], t: Int) = { // f is (i, j) => Boolean ie can the first i elements of s add up to j lazy val f: ((Int, Int)) => Boolean = memoize { case (_, 0) => true // 0 can always be achieved using empty list case (0, _) => false // we can never achieve non-zero if we have empty list case (i, j) => val k = i - 1 // try the kth element f(k, j - s(k)) || f(k, j) } f(s.length, t) } 

编辑:如下所述,这是一个线程安全的版本

 def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {self => override def apply(key: I) = self.synchronized(getOrElseUpdate(key, f(key))) } 

Class / trait level val编译为一个方法和一个私有variables的组合。 因此允许recursion定义。

另一方面,本地值只是常规variables,因此不允许recursion定义。

顺便说一句,即使你定义的def工作,它也不会做你期望的。 在每个foo调用中,都会创build一个新的函数对象fib ,并且它将拥有自己的后台映射。 你应该做的是这个(如果你真的想要一个def作为你的公共接口):

 private val fib: Memo[Int, BigInt] = Memo { case 0 => 0 case 1 => 1 case n => fib(n-1) + fib(n-2) } def foo(n: Int) = { fib(n) } 

斯卡拉斯有一个解决scheme,为什么不重用呢?

 import scalaz.Memo lazy val fib: Int => BigInt = Memo.mutableHashMapMemo { case 0 => 0 case 1 => 1 case n => fib(n-2) + fib(n-1) } 

你可以阅读更多关于Scalaz的memoization 。

可变的HashMap不是线程安全的。 另外,为基本条件分别定义case语句似乎不必要的特殊处理,而Map可以用初始值加载并传递给Memoizer。 以下将是Memoizer的签名,它接受一个备忘录(不可变Map)和公式,并返回一个recursion函数。

Memoizer看起来像

 def memoize[I,O](memo: Map[I, O], formula: (I => O, I) => O): I => O 

现在给出以下斐波那契公式,

 def fib(f: Int => Int, n: Int) = f(n-1) + f(n-2) 

斐波那契与Memoizer可以被定义为

 val fibonacci = memoize( Map(0 -> 0, 1 -> 1), fib) 

上下文不可知的通用Memoizer被定义为

  def memoize[I, O](map: Map[I, O], formula: (I => O, I) => O): I => O = { var memo = map def recur(n: I): O = { if( memo contains n) { memo(n) } else { val result = formula(recur, n) memo += (n -> result) result } } recur } 

同样,对于阶乘,一个公式是

 def fac(f: Int => Int, n: Int): Int = n * f(n-1) 

和Memoizer的因子是

 val factorial = memoize( Map(0 -> 1, 1 -> 1), fac) 

灵感:Memoization,第4章的道格拉斯·克罗克福德的好作品