在Haskell中记忆?

任何关于如何在Haskell中有效地解决以下函数的指针,对于大数目(n > 108)

 f(n) = max(n, f(n/2) + f(n/3) + f(n/4)) 

我已经看到了Haskell中的memoization的例子来解决斐波纳契数字,其中包括计算(懒洋洋地)所有的斐波纳契数字到所需的n。 但在这种情况下,对于给定的n,我们只需要计算非常less的中间结果。

谢谢

我们可以通过build立一个可以在次线性时间索引的结构来非常有效地做到这一点。

但首先,

 {-# LANGUAGE BangPatterns #-} import Data.Function (fix) 

让我们定义f,但是使用“open recursion”而不是直接调用它自己。

 f :: (Int -> Int) -> Int -> Int f mf 0 = 0 f mf n = max n $ mf (n `div` 2) + mf (n `div` 3) + mf (n `div` 4) 

你可以通过使用fix f来获得一个unmemoized fix f

这可以让你testing一下f是通过调用f来表示f的小值,例如: fix f 123 = 144

我们可以通过定义:

 f_list :: [Int] f_list = map (f faster_f) [0..] faster_f :: Int -> Int faster_f n = f_list !! n 

这样做performance得很好,并用记忆中间结果的东西取代了要花费O(n ^ 3)时间的事情。

但是,只需要线性的时间索引来findmf的备忘答案。 这意味着结果如下所示:

 *Main Data.List> faster_f 123801 248604 

是可以忍受的,但结果并不比这个好得多。 我们可以做得更好!

首先,我们定义一棵无限树:

 data Tree a = Tree (Tree a) a (Tree a) instance Functor Tree where fmap f (Tree lmr) = Tree (fmap fl) (fm) (fmap fr) 

然后我们将定义一个索引的方法,所以我们可以在O(log n)时间内find索引为n的节点:

 index :: Tree a -> Int -> a index (Tree _ m _) 0 = m index (Tree l _ r) n = case (n - 1) `divMod` 2 of (q,0) -> index lq (q,1) -> index rq 

…我们可能会发现一棵充满自然数的树是方便的,所以我们不必摆弄那些指数:

 nats :: Tree Int nats = go 0 1 where go !n !s = Tree (go l s') n (go r s') where l = n + s r = l + s s' = s * 2 

既然我们可以索引,你可以把一棵树转换成一个列表:

 toList :: Tree a -> [a] toList as = map (index as) [0..] 

你可以通过validationtoList nats给你[0..]来检查工作

现在,

 f_tree :: Tree Int f_tree = fmap (f fastest_f) nats fastest_f :: Int -> Int fastest_f = index f_tree 

就像上面的列表一样工作,但是不用花费时间来查找每个节点,而是可以在对数时间内追踪它。

结果是相当快:

 *Main> fastest_f 12380192300 67652175206 *Main> fastest_f 12793129379123 120695231674999 

事实上,它是如此之快,以至于你可以通过上面的Integer代替Int ,几乎可以瞬间获得可笑的大回答

 *Main> fastest_f' 1230891823091823018203123 93721573993600178112200489 *Main> fastest_f' 12308918230918230182031231231293810923 11097012733777002208302545289166620866358 

爱德华的答案是这样一个奇妙的gem,我已经复制它,并提供memoListmemoTree组合器的实现memoList memoTree一个函数开放recursion的forms。

 {-# LANGUAGE BangPatterns #-} import Data.Function (fix) f :: (Integer -> Integer) -> Integer -> Integer f mf 0 = 0 f mf n = max n $ mf (div n 2) + mf (div n 3) + mf (div n 4) -- Memoizing using a list -- The memoizing functionality depends on this being in eta reduced form! memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer memoList f = memoList_f where memoList_f = (memo !!) . fromInteger memo = map (f memoList_f) [0..] faster_f :: Integer -> Integer faster_f = memoList f -- Memoizing using a tree data Tree a = Tree (Tree a) a (Tree a) instance Functor Tree where fmap f (Tree lmr) = Tree (fmap fl) (fm) (fmap fr) index :: Tree a -> Integer -> a index (Tree _ m _) 0 = m index (Tree l _ r) n = case (n - 1) `divMod` 2 of (q,0) -> index lq (q,1) -> index rq nats :: Tree Integer nats = go 0 1 where go !n !s = Tree (go l s') n (go r s') where l = n + s r = l + s s' = s * 2 toList :: Tree a -> [a] toList as = map (index as) [0..] -- The memoizing functionality depends on this being in eta reduced form! memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer memoTree f = memoTree_f where memoTree_f = index memo memo = fmap (f memoTree_f) nats fastest_f :: Integer -> Integer fastest_f = memoTree f 

不是最有效的方法,但要记住:

 f = 0 : [ gn | n <- [1..] ] where gn = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4) 

当请求f !! 144 f !! 144 ,检查f !! 143 f !! 143存在,但其确切的价值是不计算的。 它仍然是一个未知的计算结果。 计算出的唯一精确值就是需要的值。

所以最初,就计算了多less而言,程序一无所知。

 f = .... 

当我们提出要求f !! 12 f !! 12 ,它开始做一些模式匹配:

 f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ... 

现在开始计算

 f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3 

这recursion地对f做了另一个需求,所以我们计算

 f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1 f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0 f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0 f !! 0 = 0 

现在我们可以涓涓回来一些

 f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1 

这意味着程序现在知道:

 f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ... 

继续滴下:

 f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3 

这意味着程序现在知道:

 f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ... 

现在我们继续计算f!!6

 f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1 f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2 f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6 

这意味着程序现在知道:

 f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ... 

现在我们继续计算f!!12

 f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3 f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4 f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13 

这意味着程序现在知道:

 f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ... 

所以计算是相当懒散地完成的。 该程序知道一些f !! 8f !! 8 f !! 8存在,它等于g 8 ,但不知道g 8是什么。

这是爱德华·科梅特(Edward Kmett)的出色答案的附录。

当我尝试他的代码时, natsindex的定义看起来很神秘,所以我编写了一个我觉得更容易理解的替代版本。

我用indexnats来定义indexnats

index' tn[1..]范围内定义。 (回想一下index t是在[0..]范围内定义的[0..] 。它通过将n视为一串比特,并通过反向读取比特来search树。 如果这个位是1 ,它就是右边的分支。 如果该位是0 ,则需要左手分支。 它到达最后一位时(它必须是1 )停止。

 index' (Tree lmr) 1 = m index' (Tree lmr) n = case n `divMod` 2 of (n', 0) -> index' ln' (n', 1) -> index' rn' 

正如nats被定义为index以便index nats n == n始终为真, nats'index'定义。

 nats' = Tree l 1 r where l = fmap (\n -> n*2) nats' r = fmap (\n -> n*2 + 1) nats' nats' = Tree l 1 r 

现在, natsindex只不过是nats'index'而已,

 index tn = index' t (n+1) nats = fmap (\n -> n-1) nats' 

正如Edward Kmett的回答所述,为了加快速度,您需要caching昂贵的计算并能够快速访问它们。

为了保持这个函数不是一元的,用一个合适的方式来build立一个无限的懒树(如前面的post所示)来实现这个目标。 如果放弃函数的非单调性质,可以将Haskell中的标准关联容器与“状态”单元(如State或ST)结合使用。

虽然主要缺点是你得到一个非单子函数,你不必自己索引结构,只能使用关联容器的标准实现。

要做到这一点,首先需要重新编写函数来接受任何types的monad:

 fm :: (Integral a, Monad m) => (a -> ma) -> a -> ma fm _ 0 = return 0 fm recf n = do recs <- mapM recf $ div n <$> [2, 3, 4] return $ max n (sum recs) 

对于你的testing,你仍然可以使用Data.Function.fix定义一个没有记忆的函数,尽pipe它有点冗长:

 noMemoF :: (Integral n) => n -> n noMemoF = runIdentity . fix fm 

然后,您可以将State monad与Data.Map结合使用来加快速度:

 import qualified Data.Map.Strict as MS withMemoStMap :: (Integral n) => n -> n withMemoStMap n = evalState (fm recF n) MS.empty where recF i = do v <- MS.lookup i <$> get case v of Just v' -> return v' Nothing -> do v' <- fm recF i modify $ MS.insert iv' return v' 

只需稍作更改,您就可以使代码适应Data.HashMap:

 import qualified Data.HashMap.Strict as HMS withMemoStHMap :: (Integral n, Hashable n) => n -> n withMemoStHMap n = evalState (fm recF n) HMS.empty where recF i = do v <- HMS.lookup i <$> get case v of Just v' -> return v' Nothing -> do v' <- fm recF i modify $ HMS.insert iv' return v' 

除了持久的数据结构之外,您还可以尝试将可变数据结构(如Data.HashTable)与ST monad结合使用:

 import qualified Data.HashTable.ST.Linear as MHM withMemoMutMap :: (Integral n, Hashable n) => n -> n withMemoMutMap n = runST $ do ht <- MHM.new recF ht n where recF ht i = do k <- MHM.lookup ht i case k of Just k' -> return k' Nothing -> do k' <- fm (recF ht) i MHM.insert ht ik' return k' 

与没有任何memoization的实现相比,这些实现中的任何一个都可以让您在巨大的input下以微秒的速度得到结果,而不必等待几秒钟。

使用Criterion作为基准,我可以观察到Data.HashMap的实现实际上比Data.Map和Data.HashTable(其时间非常相似)稍好(约20%)。

我发现基准的结果有点令人惊讶。 我最初的感觉是HashTable会超越HashMap的实现,因为它是可变的。 在最后的实现中可能会隐藏一些性能缺陷。

几年后,我看着这个,意识到有一个简单的方法来使用zipWith和一个辅助函数在线性时间内对此进行zipWith

 dilate :: Int -> [x] -> [x] dilate n xs = replicate n =<< xs 

dilate有方便的财产, dilate n xs !! i == xs !! div in dilate n xs !! i == xs !! div in dilate n xs !! i == xs !! div in

所以,假设我们给了f(0),这简化了计算

 fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4) where (.+.) = zipWith (+) infixl 6 .+. (#/) = flip dilate infixl 7 #/ 

看起来很像我们原来的问题描述,并给出一个线性解决scheme( sum $ take n fs将采取O(n))。

一个没有索引的解决scheme,不是基于Edward KMETT的。

f(n/2)f(n/4)之间共享f(n/2)f(n/4) f(n/2)f(n/4)共享f(3) )。 通过将它们保存为父variables中的单个variables,子树的计算将执行一次。

 data Tree a = Node {datum :: a, child2 :: Tree a, child3 :: Tree a} f :: Int -> Int fn = datum root where root = f' n Nothing Nothing -- Pass in the arg -- and this node's lifted children (if any). f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a f' 0 _ _ = leaf where leaf = Node 0 leaf leaf f' n m2 m3 = Node d c2 c3 where d = if n < 12 then n else max n (d2 + d3 + d4) [n2,n3,n4,n6] = map (n `div`) [2,3,4,6] [d2,d3,d4,d6] = map datum [c2,c3,c4,c6] c2 = case m2 of -- Check for a passed-in subtree before recursing. Just c2' -> c2' Nothing -> f' n2 Nothing (Just c6) c3 = case m3 of Just c3' -> c3' Nothing -> f' n3 (Just c6) Nothing c4 = child2 c2 c6 = f' n6 Nothing Nothing main = print (f 123801) -- Should print 248604. 

该代码不容易扩展到一般的memoization函数(至less,我不知道该怎么做),你真的必须考虑如何子问题重叠,但策略应该适用于一般的多个非整数参数。 (我认为它的两个string参数。)

备忘录在每次计算后被丢弃。 (同样,我在考虑两个string参数。)

我不知道这是否比其他答案更有效。 每个查找在技术上只有一到两个步骤(“看看你的孩子或你的孩子的孩子”),但可能会有很多额外的内存使用。

编辑:这个解决scheme还不正确。 分享不完整。

编辑:现在应该正确地分享子女,但是我意识到这个问题有很多不平凡的分享: n/2/2/2 2/2/2和n/3/3可能是一样的。 这个问题不适合我的策略。

爱德华·克梅特(Edward Kmett)的另一个补充答案是一个独立的例子:

 data NatTrie v = NatTrie (NatTrie v) v (NatTrie v) memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n)) where nats = go 0 1 go is = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s') where s' = 2*s index (NatTrie lvr) i | i < 0 = f (index_to_arg i) | i == 0 = v | otherwise = case (i-1) `divMod` 2 of (i',0) -> index li' (i',1) -> index ri' memoNat = memo1 id id 

用下面的方法来记忆一个整数arg(如斐波纳契)的函数:

 fib = memoNat f where f 0 = 0 f 1 = 1 fn = fib (n-1) + fib (n-2) 

只有非负参数的值才会被caching。

为了也caching消极参数的值,使用memoInt ,定义如下:

 memoInt = memo1 arg_to_index index_to_arg where arg_to_index n | n < 0 = -2*n | otherwise = 2*n + 1 index_to_arg i = case i `divMod` 2 of (n,0) -> -n (n,1) -> n 

要使用两个整数参数caching函数的值,使用memoIntInt ,定义如下:

 memoIntInt f = memoInt (\n -> memoInt (fn))