Python是否优化尾递归?

我有以下一段代码失败,出现以下错误:

RuntimeError:超过最大递归深度

我试图重写这个以允许尾递归优化(TCO)。 我相信如果TCO发生,这个代码应该是成功的。

def trisum(n, csum): if n == 0: return csum else: return trisum(n - 1, csum + n) print(trisum(1000, 0)) 

我应该得出结论,Python不会做任何类型的TCO,或者我只是需要以不同的方式定义它?

不,也不会,因为Guido喜欢能够有适当的回溯

http://neopythonic.blogspot.com.au/2009/04/tail-recursion-elimination.html

http://neopythonic.blogspot.com.au/2009/04/final-words-on-tail-calls.html

您可以使用像这样的转换手动消除递归

 >>> def trisum(n, csum): ... while True: # change recursion to a while loop ... if n == 0: ... return csum ... n, csum = n - 1, csum + n # update parameters instead of tail recursion >>> trisum(1000,0) 500500 

编辑(2015-07-02): 随着时间的推移,我的回答变得相当流行,因为它最初是一个比其他任何东西都更多的链接,所以我决定花点时间重写一遍(但是,最初的答案可以被发现在最后)。

编辑(2015-07-12):我终于发布了一个执行tail-call优化的模块(处理尾递归和延续传递样式): https : //github.com/baruchel/tco

在Python中优化尾递归

人们经常声称,尾递归不适合pythonic编码方式,不应该关心如何将其嵌入到循环中。 我不想争论这个观点, 有时候,我喜欢尝试或实现新的想法作为尾递归函数,而不是循环出于各种原因(侧重于想法而不是过程,在屏幕上同时具有20个短的函数,而不仅仅是三个“pythonic”功能,在交互式会话中工作,而不是编辑我的代码等)。

在Python中优化尾递归其实很简单。 虽然说这是不可能的或非常棘手的,但我认为可以通过优雅,简短和一般的解决办法来实现。 我甚至认为这些解决方案中的大多数不使用Python功能。 清理lambda表达式与非常标准的循环一起工作,可以实现快速,高效且完全可用的工具来实现尾递归优化。

为了个人的方便,我写了一个小模块,通过两种不同的方式来实现这样的优化。 我想在这里讨论我的两个主要功能。

干净的方法:修改Y组合器

Y组合器是众所周知的; 它允许以递归方式使用lambda函数,但它本身不允许在循环中嵌入递归调用。 Lambda微积分本身不能做这样的事情。 然而,Y组合器稍微改变可以保护递归调用实际评估。 因此可以延迟评估。

这是Y组合的着名表达式:

 lambda f: (lambda x: x(x))(lambda y: f(lambda *args: y(y)(*args))) 

随着一个非常微小的变化,我可以得到:

 lambda f: (lambda x: x(x))(lambda y: f(lambda *args: lambda: y(y)(*args))) 

函数f现在不是调用它自己,而是返回一个执行相同调用的函数,但是由于它返回,所以可以稍后从外部进行评估。

我的代码是:

 def bet(func): b = (lambda f: (lambda x: x(x))(lambda y: f(lambda *args: lambda: y(y)(*args))))(func) def wrapper(*args): out = b(*args) while callable(out): out = out() return out return wrapper 

该功能可以按以下方式使用; 这里有两个因子和斐波那契尾递归版本的例子:

 >>> from recursion import * >>> fac = bet( lambda f: lambda n, a: a if not n else f(n-1,a*n) ) >>> fac(5,1) 120 >>> fibo = bet( lambda f: lambda n,p,q: p if not n else f(n-1,q,p+q) ) >>> fibo(10,0,1) 55 

显然,递归深度不再是问题:

 >>> bet( lambda f: lambda n: 42 if not n else f(n-1) )(50000) 42 

这当然是该功能的唯一真正目的。

只有一件事情不能用这个优化来完成:它不能用于对另一个函数进行求值的尾递归函数(这是因为可调用的返回对象全部作为进一步的递归调用进行处理,没有区别)。 由于我通常不需要这样的功能,所以我对上面的代码非常满意。 然而,为了提供一个更通用的模块,我想了一下,以便找到解决这个问题的一些办法(见下一节)。

关于这个过程的速度(这不是真正的问题),它恰好是相当好的; 尾递归函数甚至比使用更简单的表达式的下面的代码快得多:

 def bet1(func): def wrapper(*args): out = func(lambda *x: lambda: x)(*args) while callable(out): out = func(lambda *x: lambda: x)(*out()) return out return wrapper 

我认为,评估一个表达式,甚至是复杂的,比评估几个简单的表达式要快得多,在第二个版本中就是这种情况。 我没有把这个新功能放在我的模块中,我也没有看到可以使用它的情况,而不是“官方”的情况。

继续传递风格与例外

这是一个更一般的功能; 它能够处理所有的尾递归函数,包括那些返回其他函数的函数。 递归调用通过使用异常从其他返回值中识别。 这个解决方案比前一个更慢, 一个更快的代码可能可以通过在主循环中使用一些特殊值作为“标志”来编写,但我不喜欢使用特殊值或内部关键字的想法。 使用异常有一些有趣的解释:如果Python不喜欢尾递归调用,当发生尾递归调用时应该引发异常,pythonic方法将捕获异常以找到一些干净解决方案,这实际上是发生在这里…

 class _RecursiveCall(Exception): def __init__(self, *args): self.args = args def _recursiveCallback(*args): raise _RecursiveCall(*args) def bet0(func): def wrapper(*args): while True: try: return func(_recursiveCallback)(*args) except _RecursiveCall as e: args = e.args return wrapper 

现在所有的功能都可以使用。 在下面的例子中, f(n)的任何正值, f(n)被评估为标识函数:

 >>> f = bet0( lambda f: lambda n: (lambda x: x) if not n else f(n-1) ) >>> f(5)(42) 42 

当然,有人可能会认为,例外情况并不是意图用于意图重定向翻译(作为一种goto语句或可能是一种延续传球方式),这是我不得不承认的。 但是,我觉得有趣的是,使用单行作为return语句的try的想法是:我们尝试返回一些东西(正常行为),但是由于递归调用发生(异常),我们不能这样做。

初步答复(2013-08-29)

我写了一个非常小的插件来处理尾递归。 您可以在我的解释中找到它: https : //groups.google.com/forum/?hl=fr#!topic/comp.lang.python/dIsnJ2BoBKs

它可以在另一个函数中嵌入一个用尾递归样式编写的lambda函数,将其评估为一个循环。

在我看来,这个小函数中最有趣的特性是函数不依赖于某些脏编程,而仅仅依赖于lambda演算:当插入另一个lambda函数时,函数的行为被改变为另一个函数看起来非常像Y-combinator。

问候。

Guido这个词在http://neopythonic.blogspot.co.uk/2009/04/tail-recursion-elimination.html

我最近在我的Python历史博客上发表了一篇关于Python功能特性的文章。 一个关于不支持尾递归消除(TRE)的评论立即引发了一些关于Python不这么做的遗憾的评论,包括其他人尝试“证明”TRE可以被添加到Python的最近博客条目的链接容易。 所以让我来捍卫我的立场(这是我不想TRE的语言)。 如果你想要一个简短的答案,这是简单的。 这是一个长的答案:

CPython不会,也可能不会支持基于Guido关于这个主题的报表的尾部调用优化。 我听说过,由于它如何修改堆栈跟踪,使得调试变得更加困难。

尝试大小的实验macropy TCO实现。

除了优化尾递归之外,您还可以通过以下方式手动设置递归深度:

 import sys sys.setrecursionlimit(5500000) print("recursion limit:%d " % (sys.getrecursionlimit()))