问与答 范围-在Scala中有通用的记忆方式吗?

griffith · 2020-02-22 17:30:09 · 热度: 7

我想记住这一点:

def fib(n: Int) = if(n <= 1) 1 else fib(n-1) + fib(n-2)
println(fib(100)) // times out

因此,我编写了此文件,并完成了令人惊讶的编译和工作(我很惊讶,因为fib在其声明中引用了自己):

case class Memo[A,B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A) = cache getOrElseUpdate (x, f(x))
}

val fib: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 
}

println(fib(100))     // prints 100th fibonacci number instantly

但是,当我尝试在val fib内声明fib时,出现编译器错误:

def foo(n: Int) = {
  val fib: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fib(n-1) + fib(n-2) 
  }
  fib(n)
} 

以上未能编译val fib

为什么在def内部声明val fib失败,但在类/对象范围之外进行声明呢?

为了澄清,为什么我可能想在def范围内声明递归的备忘函数-这是我对子集和问题的解决方案:

/**
   * Subset sum algorithm - can we achieve sum t using elements from s?
   *
   * @param s set of integers
   * @param t target
   * @return true iff there exists a subset of s that sums to t
   */
  def subsetSum(s: Seq[Int], t: Int): Boolean = {
    val max = s.scanLeft(0)((sum, i) => (sum + i) max sum)  //max(i) =  largest sum achievable from first i elements
    val min = s.scanLeft(0)((sum, i) => (sum + i) min sum)  //min(i) = smallest sum achievable from first i elements

    val dp: Memo[(Int, Int), Boolean] = Memo {         // dp(i,x) = can we achieve x using the first i elements?
      case (_, 0) => true        // 0 can always be achieved using empty set
      case (0, _) => false       // if empty set, non-zero cannot be achieved
      case (i, x) if min(i) <= x && x <= max(i) => dp(i-1, x - s(i-1)) || dp(i-1, x)  // try with/without s(i-1)
      case _ => false            // outside range otherwise
    }

    dp(s.length, t)
  }

猜你喜欢:
共收到 4 条回复
tapasvi #1 · 2020-02-22 17:30:09

我找到了一种更好的使用Scala进行记忆的方法:

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I) = getOrElseUpdate(key, f(key))
}

现在您可以编写斐波那契,如下所示:

lazy val fib: Int => BigInt = memoize {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2)
}

这是一个带有多个参数(choose函数)的参数:

lazy val c: ((Int, Int)) => BigInt = memoize {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n - r)
  case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
}

这是子集总和问题:

// is there a subset of s which has sum = t
def isSubsetSumAchievable(s: Vector[Int], t: Int) = {
  // f is (i, j) => Boolean i.e. can the first i elements of s add up to j
  lazy val f: ((Int, Int)) => Boolean = memoize {
    case (_, 0) => true        // 0 can always be achieved using empty list
    case (0, _) => false       // we can never achieve non-zero if we have empty list
    case (i, j) => 
      val k = i - 1            // try the kth element
      f(k, j - s(k)) || f(k, j)
  }
  f(s.length, t)
}

编辑:如下所述,这是线程安全的版本

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {self =>
  override def apply(key: I) = self.synchronized(getOrElseUpdate(key, f(key)))
}
jonah #2 · 2020-02-22 17:30:10

类/特征级别def编译为方法和私有变量的组合。 因此,允许递归定义。

另一方面,本地defs只是常规变量,因此不允许递归定义。

顺便说一句,即使您定义的def有效,它也无法满足您的期望。 每次调用foo时,都会创建一个新的函数对象fib,它将具有自己的支持映射。 相反,您应该做的是此操作(如果您确实希望将def作为您的公共接口):

private val fib: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 
}

def foo(n: Int) = {
  fib(n)
} 
sunil #3 · 2020-02-22 17:30:11

Scalaz有一个解决方案,为什么不重复使用呢?

import scalaz.Memo
lazy val fib: Int => BigInt = Memo.mutableHashMapMemo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-2) + fib(n-1)
}

您可以在Scalaz中阅读有关备注的更多信息。

jayden #4 · 2020-02-22 17:30:12

可变的HashMap不是线程安全的。 同样为基本条件单独定义case语句似乎是不必要的特殊处理,而Map可以加载初始值并传递给Memoizer。 紧随其后的是Memoizer的签名,它在其中接受备忘录(不可变地图)和公式并返回递归函数。

备忘录看起来像

def memoize[I,O](memo: Map[I, O], formula: (I => O, I) => O): I => O

现在给出以下斐波那契公式,

def fib(f: Int => Int, n: Int) = f(n-1) + f(n-2)

带有备忘录的斐波那契可以定义为

val fibonacci = memoize( Map(0 -> 0, 1 -> 1), fib)

上下文无关的通用备忘录定义为

    def memoize[I, O](map: Map[I, O], formula: (I => O, I) => O): I => O = {
        var memo = map
        def recur(n: I): O = {
          if( memo contains n) {
            memo(n) 
          } else {
            val result = formula(recur, n)
            memo += (n -> result)
            result
          }
        }
        recur
      }

同样,对于阶乘,公式为

def fac(f: Int => Int, n: Int): Int = n * f(n-1)

和Memoizer的阶乘是

val factorial = memoize( Map(0 -> 1, 1 -> 1), fac)

灵感:回忆,道格拉斯·克罗克福德(Java)的第4章

需要 登录 后方可回复, 如果你还没有账号请点击这里 注册