Save all intermediate variables in a function, should the function fail


I find myself frequently running into this sort of problem. I have a function like

def compute(input):
    result = two_hour_computation(input)
    result = post_processing(result)
    return result

and post_processing(result) fails. Now the obvious thing to do is to change the function to

import pickle

def compute(input):
    result = two_hour_computation(input)
    with open('intermediate_result.pickle', 'wb') as f:
        pickle.dump(result, f)
    result = post_processing(result)
    return result

but I don't usually remember to write all my functions that way. What I wish I had was a decorator like:

@return_intermediate_results_if_something_goes_wrong
def compute(input):
    result = two_hour_computation(input)
    result = post_processing(result)
    return result

Does something like that exist? I can't find it on Google.

There are 7 answers

Answer by Daniil Fajnberg

The "outside" of a function has no access to the state of local variables inside the function at runtime whatsoever. So this cannot be solved with a decorator.

In any case, I would argue that catching errors and saving valuable intermediary results is the programmer's responsibility and should be done explicitly. If you "forget" to do that, it must not have been that important to you.

That being said, situations like "do X in case either A, B, or C raises an exception" are a typical use case for context managers. You can write your own context manager that acts as a bucket for your intermediary result (in place of a variable) and performs a save action if an exception causes the block to exit.

Something like this:

from __future__ import annotations
from types import TracebackType
from typing import Generic, Optional, TypeVar

T = TypeVar("T")

class Saver(Generic[T]):
    def __init__(self, initial_value: Optional[T] = None) -> None:
        self._value = initial_value

    def __enter__(self) -> Saver[T]:
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # An exception is escaping the with-block: save the last value.
        # Returning None (falsy) lets the exception propagate as usual.
        if exc_type is not None:
            self.save()

    def save(self) -> None:
        print(f"saved {self.value}!")

    @property
    def value(self) -> T:
        if self._value is None:
            raise RuntimeError
        return self._value

    @value.setter
    def value(self, value: T) -> None:
        self._value = value

Obviously, instead of print(f"saved {self.value}!") inside save you would do something like this:

        with open('intermediate_result.pickle', 'wb') as f:
            pickle.dump(self.value, f)

Now all you need to remember is to wrap those actions in a with-statement and assign intermediary results to the value property of your context manager. To demonstrate:

def x_times_2(x: float) -> float:
    return x * 2


def one_over_x_minus_2(x: float) -> float:
    return 1 / (x - 2)


def main() -> None:
    with Saver(1.) as s:
        s.value = x_times_2(s.value)
        s.value = one_over_x_minus_2(s.value)
    print(s.value)


if __name__ == "__main__":
    main()

The output:

saved 2.0!
Traceback (most recent call last):
  [...]
    return 1 / (x - 2)
           ~~^~~~~~~~~
ZeroDivisionError: float division by zero

As you can see, the intermediary computed value 2.0 was "saved", even though the next function raised an exception.

It is worth noting that in this example, the context manager calls save only if an exception was encountered, not if the context is exited "peacefully". If you wanted, you could make this unconditional of course.
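For example, the unconditional variant would simply drop the check:

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.save()  # save no matter how the context is exited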

This may not be as convenient as just slapping a decorator onto a function, but it gets the job done. And IMO the fact that you still have to consciously wrap your important actions in this context is a good thing, because it teaches you to pay special attention to these things.

This is the typical approach of implementing things like database transactions in Python by the way (e.g. in SQLAlchemy).
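For illustration, here is a minimal sketch of that pattern with SQLAlchemy (assuming SQLAlchemy 1.4+; the in-memory SQLite database and the table are made up for the example):

from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # throwaway in-memory database
# engine.begin() commits on a clean exit and rolls back if an
# exception escapes the with-block, much like Saver's __exit__ hook.
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE t (x INTEGER)"))
    conn.execute(text("INSERT INTO t (x) VALUES (1)"))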

PS

To be fair, I should probably qualify my initial statement a bit. You could of course just use non-local state in your function, even though that is generally discouraged for good reason. In super simple terms, if, in your example, result were a global variable (and you declared global result inside the function), this could in fact be solved by a decorator. But I would not recommend that approach because global state is an anti-pattern. (And it would still require you to remember to use whatever global variable you designated for that job every time.)
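A minimal sketch of that (discouraged) approach, reusing the placeholder functions from the question (the decorator name is arbitrary):

import pickle

result = None  # module-level state: an anti-pattern, shown for illustration only

def save_global_result_on_error(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            # Unlike a local, the module-level variable is visible here.
            with open('intermediate_result.pickle', 'wb') as f:
                pickle.dump(result, f)
            raise
    return wrapper

@save_global_result_on_error
def compute(input):
    global result
    result = two_hour_computation(input)
    result = post_processing(result)
    return result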

Answer by Wombatz

I'm not saying this is a good idea, but it is possible to read the local variables of a function after it has raised an exception:

try:
    my_func_that_raises(...)
except Exception as e:
    traceback = e.__traceback__
    function_frame = traceback.tb_next.tb_frame
    all_local_variables_until_crash = function_frame.f_locals

You can of course wrap this in a decorator:

from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec('P')
R = TypeVar('R')


def save_by_inspect(f: Callable[P, R]) -> Callable[P, R]:
    def save(*args: P.args, **kwargs: P.kwargs) -> R:
        try:
            return f(*args, **kwargs)  # call the function
        except Exception as e:  # oh no, it crashed
            tb = e.__traceback__
            assert tb is not None  # you can implement a different strategy when the traceback is None
            function_tb = tb.tb_next
            assert function_tb is not None
            print(function_tb.tb_frame.f_locals)  # save the local variables
            raise

    return save

This implementation just prints the local variables (Mapping[str, Any]).
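If you want to persist them instead, a minimal sketch (the helper name and file name are arbitrary) could keep only the picklable values, since frame locals often contain unpicklable objects:

import pickle

def save_locals(local_vars: dict) -> None:
    # Frames often hold open files, locks, modules etc., which
    # cannot be pickled, so filter those out first.
    picklable = {}
    for name, value in local_vars.items():
        try:
            pickle.dumps(value)
        except Exception:
            continue
        picklable[name] = value
    with open('crash_locals.pickle', 'wb') as f:
        pickle.dump(picklable, f)

You could call such a helper in place of the print inside the decorator.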

Usage is like this:

@save_by_inspect
def foo(crash: bool) -> int:
    expensive = 1 + 2 + 3
    if crash:
        raise RuntimeError("i crashed")

    return expensive + 123


assert foo(crash=False) == 129

foo(crash=True)

The assert succeeds and the second call prints the local variables ({'crash': True, 'expensive': 6}) and then raises the RuntimeError again.

But there is an alternative. If you can remember to put a decorator onto your function, you can also add a minimally intrusive safety feature. The idea is to yield every result that is important and then return the final result at the end. So your function would look like this:

R = TypeVar('R')  # part of the library
Saved = Generator[Any, None, R]  # part of the library

@save_by_yield  # part of the library
def foo2(crash: bool) -> Saved[int]:  # user code
    x = 1
    yield x  # this is saved on crash
    if crash:
        raise RuntimeError("i crashed")
    return 123  # the final result

The decorator can then run the function (not line by line, as OP suggested, but yield by yield) and gather the yielded results:

from collections.abc import Callable, Generator
from typing import Any, ParamSpec, TypeVar, cast

P = ParamSpec('P')
R = TypeVar('R')
Saved = Generator[Any, None, R]


def save_by_yield(f: Callable[P, Saved[R]]) -> Callable[P, R]:
    def save(*args: P.args, **kwargs: P.kwargs) -> R:
        generator = f(*args, **kwargs)
        save_this = []

        while True:
            try:
                save_this.append(next(generator))  # get an expensive result
            except StopIteration as stop:  # function is done
                return cast(R, stop.value)  # StopIteration is not generic :(
            except Exception:  # function crashed
                print(save_this)
                raise
    return save

Whenever the generator yields, we store the result. When it raises an exception, the stored results are saved (printed here); when it completes (StopIteration), the final result is simply returned.

Note: The decorator is typed correctly, i.e. reveal_type(foo(...)) is int, and when you forget to yield a result, mypy will complain with:

Incompatible return value type (got "int", expected "Generator[Any, None, int]")

Not pretty, but it's something.

Note 2: I omitted @functools.wraps to shorten the code.
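For reference, it would go directly on the inner wrapper, e.g. for save_by_yield (annotations elided for brevity):

import functools

def save_by_yield(f):
    @functools.wraps(f)  # copy f's __name__, __doc__, etc. onto the wrapper
    def save(*args, **kwargs):
        ...  # body as above
    return save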

Answer by SrPanda

It can be done with a decorator, but the decorator should go on the underlying function, mainly because that is much simpler. Assuming you only want to reuse computations, a cache system should be ideal.

The way I implemented the cache is quite straightforward. The decorator @cache.result builds a call signature (the function name plus an MD5 hash of the arguments) and, if the function completes, writes whatever the function returns to disk as a file. @cache.with_key('key') works the same way but is global: any function decorated with the same key will return the same cached value. In both cases the decorators add no extra code or complexity at the call site.

import os
import io
import pickle
import hashlib

class Cache:

    cache_dir = os.path.join(
        os.path.dirname(__file__), 'cache'
    )

    def __init__(self, log=False):
        self.log_enabled = log

        if not os.path.exists(Cache.cache_dir):
            os.mkdir(Cache.cache_dir)

    def __log(self, text):
        if self.log_enabled:
            print(text)

    def _call_to_key(self, func, args, kwargs):
        key = str(func.__name__)
        buff = io.BytesIO()

        for e in [args, kwargs]:
            pickle.dump(e, buff)

            md5_str = hashlib.md5(
                buff.getvalue(), 
                usedforsecurity=False
            ).hexdigest()
            key += f'-{md5_str}'

            buff.seek(0)
            buff.truncate(0)

        buff.close()

        return key

    def __cache_or_invoke(self, key, func, args, kwargs):
        cache_key = key or self._call_to_key(func, args, kwargs)
        cache_file = os.path.join(Cache.cache_dir, cache_key)

        # Every key is a file name in the folder
        if cache_key in os.listdir(Cache.cache_dir):
            self.__log(f'Cache hit  {cache_key}')

            with open(cache_file, 'rb') as file:
                return pickle.load(file)

        else:
            self.__log(f'Cache miss {cache_key}')

            ret = func(*args, **kwargs)

            with open(cache_file, 'wb') as file:
                pickle.dump(ret, file)

            return ret

    def with_key(self, key):
        def decorator(func):
            def wrapper(*args, **kwargs):
                return self.__cache_or_invoke(key, func, args, kwargs)
            return wrapper
        return decorator

    def result(self, func):
        def wrapper(*args, **kwargs):
            return self.__cache_or_invoke(None, func, args, kwargs)
        return wrapper

    def clear(self, key=None):
        if os.path.exists(Cache.cache_dir):
            for file in os.listdir(Cache.cache_dir):

                # Match either the exact key (with_key caches) or the
                # function-name prefix (result caches).
                if key is not None:
                    if file != key and file.split('-', 1)[0] != key:
                        continue

                os.remove(
                    os.path.join(Cache.cache_dir, file)
                )
                self.__log(f'Cache cleared {file}')

cache = Cache(log=False)

class Rational:
    def __init__(self, den, num):
        self.den = den
        self.num = num

    def __str__(self):
        return f'{self.den}/{self.num}'


@cache.result
def half(rat):
    rat.num *= 2
    return rat

@cache.with_key('float')
def to_float(rat):
    return rat.den / rat.num


if __name__ == '__main__':
    f1 = Rational(1, 3)
    f2 = Rational(5, 2)

    # If not cleared the next [to_float]
    # will return the cache
    cache.clear(key='float')

    half1 = half(f1)
    print(half1)
    print(to_float(half1))

    half2 = half(f2)
    print(half2)
    # Wrong value, cache is set 
    # from the previous call
    print(to_float(half2))

    # cache.clear()

As a side note, there is no need to use cache.<func> as a decorator; it can also be called as a regular function:

cache.result(half)(half2)
cache.with_key('float')(to_float)(half2)
cache.with_key('float')(lambda _: None)(half2)

Answer by CoderWithAGoodName

You can create a custom decorator in Python to automatically save intermediate results when something goes wrong. Here's an example of how you can implement such a decorator:

import pickle

def return_intermediate_results_if_something_goes_wrong(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # 'result' is a local of func, not of this wrapper, so fetch it
            # from the traceback frame where the exception was raised.
            frame_locals = e.__traceback__.tb_next.tb_frame.f_locals
            intermediate_file = 'intermediate_result.pickle'
            print(f"An error occurred: {e}. Saving intermediate result to {intermediate_file}.")
            with open(intermediate_file, 'wb') as f:
                pickle.dump(frame_locals.get('result'), f)
            raise  # Re-raise the exception

    return wrapper

# You can use the decorator like this:
@return_intermediate_results_if_something_goes_wrong
def compute(input):
    result = two_hour_computation(input)
    result = post_processing(result)
    return result

With this decorator, if an exception is raised within your compute function, it will catch the exception, retrieve the intermediate result from the crashed function's frame, save it to a pickle file, and then re-raise the exception so that you're aware of the error.
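A quick way to see it in action, using made-up stand-ins for the two expensive steps:

def two_hour_computation(x):
    return x * 2  # stand-in for the expensive step

def post_processing(x):
    raise ValueError("boom")  # simulate the failure

@return_intermediate_results_if_something_goes_wrong
def compute(input):
    result = two_hour_computation(input)
    result = post_processing(result)
    return result

compute(21)  # prints the message, pickles 42, then re-raises ValueError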

Answer by blhsing

The values of all local variables at the moment an exception is raised can be found in the frames of the call stack stored on the traceback object that comes with the exception object.

The traceback.TracebackException class offers a readable way to output local variable values from an exception object when constructed with the capture_locals option enabled. Since you are interested in only the last function causing the exception, we output only the content of the last frame, at index -1:

import traceback as tb

def compute(input):
    result = input - 1
    result = 1 / result
    return result

try:
    compute(1)
except Exception as e:
    print(
        tb.TracebackException.from_exception(e, capture_locals=True)
        .stack.format()[-1]
    )

This outputs:

  File "compute.py", line 5, in compute
    result = 1 / result
             ~~^~~~~~~~
    input = 1
    result = 0

It then becomes trivial to create a decorator that does so for the wrapped function:

import traceback as tb

def print_intermediate_results_if_something_goes_wrong(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(
                tb.TracebackException.from_exception(e, capture_locals=True)
                .stack.format()[-1]
            )
    return wrapper

@print_intermediate_results_if_something_goes_wrong
def compute(input):
    result = input - 1
    result = 1 / result
    return result

compute(1)


If you need this decorator applied to a lot of functions, you might as well override the default exception handler sys.excepthook so you can forget about applying a decorator:

import sys
import traceback as tb

def print_intermediate_results_if_something_goes_wrong(*args):
    print(
        tb.TracebackException(*args, capture_locals=True)
        .stack.format()[-1]
    )

sys.excepthook = print_intermediate_results_if_something_goes_wrong


Or if you would like the entire stack dumped instead for a better understanding of the context:

import sys
import traceback as tb

def print_intermediate_results_if_something_goes_wrong(*args):
    print(*tb.TracebackException(*args, capture_locals=True).format(), sep='\n')

sys.excepthook = print_intermediate_results_if_something_goes_wrong


Answer by Chad Brewbaker

import functools
import sys

def return_all_locals_if_something_goes_wrong_to_var(saved_locals_var):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                # Retrieve the frame where the exception was raised
                tb = sys.exc_info()[2]  # Get current traceback
                frame = tb.tb_next.tb_frame  # Go one level down to the frame of the wrapped function

                # Save all local variables from that frame into the provided container
                saved_locals_var[0] = frame.f_locals
                # Throw a SIGTRAP here to attach with GDB
                raise  # Re-raise the exception
        return wrapper
    return decorator

Cursed but it seemed to work for the examples I tried.
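For example, a minimal usage sketch, assuming a one-element list as the writable container:

saved = [None]  # mutable container the decorator writes into

@return_all_locals_if_something_goes_wrong_to_var(saved)
def compute(x):
    result = x * 2
    return 1 / (result - 4)

try:
    compute(2)
except ZeroDivisionError:
    print(saved[0])  # {'x': 2, 'result': 4}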

Answer by mattiatantardini

I think an easier approach is to explicitly save your results after the computation ends, whether or not the following processing fails. In this case, you can design a utility decorator to be used on the computationally heavy function.

The decorator could look like this:

import pickle

def save_pickle(func):
    def wrapper(*args, **kwargs):
        results = func(*args, **kwargs)
        with open("results.pkl", "wb") as f:  # 'wb', not 'rb': we are writing
            pickle.dump(results, f)
        return results
    return wrapper

and use it like this whenever you need:

@save_pickle
def two_hours_computations(input):
    # function definition here

You may generalize the decorator to take the destination path for the pickle file as an argument.
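A minimal sketch of that generalization (the decorator factory name save_pickle_to is arbitrary):

import pickle

def save_pickle_to(path):
    def decorator(func):
        def wrapper(*args, **kwargs):
            results = func(*args, **kwargs)
            with open(path, "wb") as f:  # write to the caller-chosen path
                pickle.dump(results, f)
            return results
        return wrapper
    return decorator

@save_pickle_to("two_hour_result.pkl")
def two_hours_computations(input):
    # function definition here
    ...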