Why can't I mock the decorator of the function?

30 views Asked by At

I'm trying to write a unit test. I have a decorator that checks whether the argument passed to the function is actually an argument of the expected type. @check_types_method

Its usage looks like this:

class Calculator:

    @check_types_method
    def call(self, a: MegaNumber, b: int) -> int:
        return a.value + b
    

If the variable a is not a MegaNumber, an exception will be thrown. It works great, but it creates problems in tests. For example, I need to replace the MegaNumber mock with an object, and if I do this, I get an error in the test.

TypeError: The argument 'a' must be of type <class 'mega_number.MegaNumber'>, received: MagicMock

For the test, you need to lock the decorator so that it does not do anything. And skipped Mock objects. However, I do not know how to do this and I do not understand what the reason is. My best attempt looks like this:

import unittest
from unittest.mock import MagicMock, patch
from calculator import Calculator


@patch('calculator.check_types_method', lambda x: x)
class TestCalculator(unittest.TestCase):

    def test_call(self):
        mock_mega_number = MagicMock()
        mock_mega_number.value = 10
        calculator = Calculator()
        result = calculator.call(mock_mega_number, 5)

        self.assertEqual(result, 15, "The call method should return the sum of MegaNumber's "
                                     "value and the integer provided.")

but the decorator loads before the patch is applied - and I can't beat it. Can you tell me what to do?

the code for playback: check_types_method.py

from inspect import signature
from functools import wraps
from typing import get_origin, get_args, Union


def check_types_method(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        sig = signature(func)
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        for name, value in bound.arguments.items():
            expected_type = sig.parameters[name].annotation
            if expected_type is not sig.empty:
                if not is_of_generic_type(value, expected_type):
                    raise TypeError(
                        f"The argument '{name}' must be of type {expected_type}, received: {type(value).__name__}")
        return func(*args, **kwargs)

    def is_of_generic_type(obj, generic_type):
        origin_type = get_origin(generic_type)
        arg_types = get_args(generic_type)
        if origin_type is Union:
            return any(is_of_generic_type(obj, arg) for arg in arg_types)
        if origin_type is None:
            # Non-generic types
            return isinstance(obj, generic_type)
        if not isinstance(obj, origin_type):
            return False
        if origin_type in (list, set):
            element_type = arg_types[0]
            return all(is_of_generic_type(item, element_type) for item in obj)
        elif origin_type is dict:
            key_type, value_type = arg_types
            return all(is_of_generic_type(k, key_type) and is_of_generic_type(v, value_type) for k, v in obj.items())
        elif origin_type is tuple:
            if len(arg_types) == 2 and arg_types[1] is Ellipsis:
                return all(isinstance(item, arg_types[0]) for item in obj)
            else:
                return len(obj) == len(arg_types) and all(
                    is_of_generic_type(item, t) for item, t in zip(obj, arg_types))
        return False

    return wrapper

mega_number.py

class MegaNumber:

    def __init__(self):
        self.value = 10

calculator.py

from check_types_method import check_types_method
from mega_number import MegaNumber


class Calculator:

    @check_types_method
    def call(self, a: MegaNumber, b: int) -> int:
        return a.value + b

test_calculator.py

import unittest
from unittest.mock import MagicMock, patch
from calculator import Calculator


@patch('calculator.check_types_method', lambda x: x)
class TestCalculator(unittest.TestCase):

    def test_call(self):
        mock_mega_number = MagicMock()
        mock_mega_number.value = 10
        calculator = Calculator()
        result = calculator.call(mock_mega_number, 5)

        self.assertEqual(result, 15, "The call method should return the sum of MegaNumber's "
                                     "value and the integer provided.")
0

There are 0 answers