I'm implementing a Trampoline in Python, in order to write recursive functions with stack safety (since CPython does not feature TCO). It looks like this:
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar, cast
A = TypeVar('A', covariant=True)
# ``B`` and ``C`` appear in evaluated annotations below and as a class
# parameter of ``AndThen``, so they must exist before the classes are built.
B = TypeVar('B')
C = TypeVar('C')


class Trampoline(Generic[A], ABC):
    """
    Base class for Trampolines. Useful for writing stack-safe
    recursive functions.

    A trampoline only describes a computation; nothing is evaluated until
    ``run`` interprets it.
    """

    @abstractmethod
    def _resume(self) -> 'Trampoline[A]':
        """
        Perform one step of the interpreter loop in ``run`` and return the
        trampoline to continue with.
        """

    @abstractmethod
    def _handle_cont(
        self, cont: Callable[[A], 'Trampoline[B]']
    ) -> 'Trampoline[B]':
        """
        Build the next step when this trampoline is the ``sub`` of an
        ``AndThen`` node whose continuation is ``cont``.
        """

    @property
    def _is_done(self) -> bool:
        return isinstance(self, Done)

    def and_then(self, f: Callable[[A], 'Trampoline[B]']) -> 'Trampoline[B]':
        """
        Apply ``f`` to the value wrapped by this trampoline.

        Nothing is evaluated here: the bind is recorded as an ``AndThen``
        node and carried out later by ``run``.

        Args:
            f: function to apply to the value in this trampoline
        Return:
            Result of applying ``f`` to the value wrapped by
            this trampoline
        """
        return AndThen(self, f)

    def map(self, f: Callable[[A], B]) -> 'Trampoline[B]':
        """
        Map ``f`` over the value wrapped by this trampoline.

        Args:
            f: function to apply to the value in this trampoline
        Return:
            new trampoline wrapping the result of ``f``
        """
        return self.and_then(lambda a: Done(f(a)))

    def run(self) -> A:
        """
        Interpret a structure of trampolines to produce a result.

        Steps are taken by iterating a ``while`` loop rather than by Python
        recursion. The stack therefore stays flat as long as each
        ``_resume`` step returns without recursing itself.

        Return:
            result of interpreting this structure of
            trampolines
        """
        trampoline = self
        while not trampoline._is_done:
            trampoline = trampoline._resume()
        return cast(Done[A], trampoline).a


@dataclass(frozen=True)
class Done(Trampoline[A]):
    """
    Represents the result of a recursive computation.
    """
    a: A  # the final value returned by ``run``

    def _resume(self) -> Trampoline[A]:
        return self

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        return cont(self.a)


@dataclass(frozen=True)
class Call(Trampoline[A]):
    """
    Represents a recursive call.

    ``thunk`` is not invoked when the node is built; it is invoked only when
    the interpreter reaches this node.
    """
    thunk: Callable[[], Trampoline[A]]  # deferred next step

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        return self.thunk().and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[A]:
        return self.thunk()  # type: ignore


@dataclass(frozen=True)
class AndThen(Generic[A, B], Trampoline[B]):
    """
    Represents monadic bind for trampolines as a class to avoid
    deep recursive calls to ``Trampoline.run`` during interpretation.
    """
    sub: Trampoline[A]  # computation whose result feeds ``cont``
    cont: Callable[[A], Trampoline[B]]  # continuation applied to that result

    def _handle_cont(self,
                     cont: Callable[[B], Trampoline[C]]) -> Trampoline[C]:
        return self.sub.and_then(self.cont).and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[B]:
        return self.sub._handle_cont(self.cont)  # type: ignore

    def and_then(  # type: ignore
        self, f: Callable[[B], Trampoline[C]]
    ) -> Trampoline[C]:
        # Reassociate ``(sub >>= cont) >>= f`` into
        # ``sub >>= (lambda x: cont(x) >>= f)``. The new node keeps the same
        # ``sub``, so left-nested chains (e.g. built with ``reduce``) do not
        # stack AndThen nodes inside ``sub``. The ``Call`` defers
        # ``cont(x).and_then(f)`` until the interpreter reaches it.
        return AndThen(
            self.sub,
            lambda x: Call(lambda: self.cont(x).and_then(f))  # type: ignore
        )
Now, I need a monadic sequence operator. My initial take looked like this:
from typing import Iterable
from functools import reduce
def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    """
    Combine trampolines into one trampoline of all their results, in order.

    Every element is chained with ``and_then``, so this function never calls
    ``run`` itself; the whole sequence is interpreted by the outer
    ``Trampoline.run`` loop.

    Args:
        iterable: trampolines to combine; consumed once, when ``sequence``
            is called
    Return:
        trampoline whose result is a tuple of the results of ``iterable``
    """
    def combine(result: Trampoline[object],
                ta: Trampoline[A]) -> Trampoline[object]:
        # Prepend onto a cons list ``(head, tail)``: O(1) per element.
        # Appending with ``as_ + (a,)`` would copy the accumulated tuple
        # every time and make the whole sequence O(n**2).
        return result.and_then(lambda acc: ta.map(lambda a: (a, acc)))

    def to_tuple(cons: object) -> tuple:
        # Unroll the cons list (newest element first) back into input order.
        items = []
        while cons:  # the empty list is ``()``, which is falsy
            head, cons = cons
            items.append(head)
        items.reverse()
        return tuple(items)

    return reduce(combine, iterable, Done(())).map(to_tuple)
That works, but the overhead of all the function calls resulting from reducing a long list of trampolines in this way absolutely kills performance.
So instead I tried this:
def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    """
    Combine trampolines into one trampoline of all their results, in order.

    This is faster than chaining with ``and_then`` because each element is
    run directly. However, it is NOT stack-safe in general: the thunk calls
    ``run()`` on each element from inside the outer interpreter loop.

    If running an element reaches another ``sequence`` thunk, each such
    level adds another nested ``run()`` (plus ``_resume`` and thunk frames)
    to the Python stack. A recursive function that calls ``sequence`` at
    every level of recursion is one example. Deep enough recursion then
    raises ``RecursionError``.

    Args:
        iterable: trampolines to combine; materialised immediately so the
            returned trampoline can be run more than once, even when a
            generator is passed
    Return:
        trampoline whose result is a tuple of the results of ``iterable``
    """
    # Materialise now. If a generator were iterated lazily inside the thunk,
    # a second ``run()`` of the returned trampoline would see it exhausted
    # and silently produce ``()``.
    trampolines = tuple(iterable)

    def thunk() -> Trampoline[Iterable[A]]:
        return Done(tuple([t.run() for t in trampolines]))

    return Call(thunk)
Now, my gut feeling is that the second implementation of `sequence` isn't stack safe because it calls `run`, which means that `run` will be calling `run` during interpretation (through `Call.thunk`, but nonetheless). However, I can't seem to produce a stack overflow no matter how I mix and match.
For example, I thought this should do it:
# Build 10000 independent ``sequence`` trampolines, each over two ``Done``
# values, then chain them left to right with ``and_then``.
t, *ts = [sequence(Done(v) for v in range(2)) for _ in range(10000)]


def combine(t1, t2):
    """Chain ``t2`` after ``t1``, discarding ``t1``'s result."""
    return t1.and_then(lambda _: t2)


final = reduce(combine, ts, t)
# Why this does not overflow: the 10000-long ``and_then`` chain is driven by
# the single ``while`` loop in ``final.run()``. With the second ``sequence``
# definition above, each thunk calls ``run()`` only on ``Done`` values.
# ``Done`` is already done, so those inner ``run()`` calls return without
# looping, and ``run`` is never nested more than one level deep.
# Nesting grows only if the trampolines passed to ``sequence`` themselves
# reach another ``sequence`` thunk when run. One example is a recursive
# function that calls ``sequence`` at every level of recursion.
final.run() # My gut feeling says this should overflow the stack, but it doesn't
I've tried countless other examples, but none of them overflows the stack. My gut feeling remains that this shouldn't work.
I need someone to either convince me that trampolining the interpreter loop in this way is actually stack safe, or show me an example where it overflows the stack.